[XLA:CPU] Port XLA:CPU JIT to OrcV2
OrcV1 is going away in 6154c4115c
PiperOrigin-RevId: 338071271
Change-Id: I429a09c39f62c76974482092a3ba14163bf74802
This commit is contained in:
parent
5fe5a49092
commit
4f7ef4ecb4
@ -552,6 +552,7 @@ cc_library(
|
|||||||
"@llvm-project//llvm:IPO",
|
"@llvm-project//llvm:IPO",
|
||||||
"@llvm-project//llvm:MC",
|
"@llvm-project//llvm:MC",
|
||||||
"@llvm-project//llvm:Object",
|
"@llvm-project//llvm:Object",
|
||||||
|
"@llvm-project//llvm:OrcJIT",
|
||||||
"@llvm-project//llvm:Support",
|
"@llvm-project//llvm:Support",
|
||||||
"@llvm-project//llvm:Target",
|
"@llvm-project//llvm:Target",
|
||||||
],
|
],
|
||||||
|
@ -80,8 +80,8 @@ class FilteredPassManager : public llvm::legacy::PassManager {
|
|||||||
};
|
};
|
||||||
} // anonymous namespace
|
} // anonymous namespace
|
||||||
|
|
||||||
std::unique_ptr<llvm::MemoryBuffer> CompilerFunctor::operator()(
|
llvm::Expected<std::unique_ptr<llvm::MemoryBuffer>> CompilerFunctor::operator()(
|
||||||
llvm::Module& module) const {
|
llvm::Module& module) {
|
||||||
FilteredPassManager module_passes(disable_expensive_passes_);
|
FilteredPassManager module_passes(disable_expensive_passes_);
|
||||||
llvm::legacy::FunctionPassManager function_passes(&module);
|
llvm::legacy::FunctionPassManager function_passes(&module);
|
||||||
|
|
||||||
@ -155,7 +155,7 @@ std::unique_ptr<llvm::MemoryBuffer> CompilerFunctor::operator()(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return memory_buffer;
|
return std::move(memory_buffer);
|
||||||
}
|
}
|
||||||
|
|
||||||
static std::vector<llvm::VecDesc> VectorFunctionsForTargetLibraryInfoImpl() {
|
static std::vector<llvm::VecDesc> VectorFunctionsForTargetLibraryInfoImpl() {
|
||||||
|
@ -16,6 +16,7 @@ limitations under the License.
|
|||||||
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_COMPILER_FUNCTOR_H_
|
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_COMPILER_FUNCTOR_H_
|
||||||
#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_COMPILER_FUNCTOR_H_
|
#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_COMPILER_FUNCTOR_H_
|
||||||
|
|
||||||
|
#include "llvm/ExecutionEngine/Orc/IRCompileLayer.h"
|
||||||
#include "llvm/IR/LegacyPassManager.h"
|
#include "llvm/IR/LegacyPassManager.h"
|
||||||
#include "llvm/IR/Module.h"
|
#include "llvm/IR/Module.h"
|
||||||
#include "llvm/IR/Operator.h"
|
#include "llvm/IR/Operator.h"
|
||||||
@ -29,7 +30,7 @@ namespace cpu {
|
|||||||
|
|
||||||
// Functor class for compiling an LLVM module down to an object file. For use by
|
// Functor class for compiling an LLVM module down to an object file. For use by
|
||||||
// Orc JIT compile layer.
|
// Orc JIT compile layer.
|
||||||
class CompilerFunctor {
|
class CompilerFunctor : public llvm::orc::IRCompileLayer::IRCompiler {
|
||||||
public:
|
public:
|
||||||
explicit CompilerFunctor(
|
explicit CompilerFunctor(
|
||||||
llvm::TargetMachine* target_machine, int opt_level,
|
llvm::TargetMachine* target_machine, int opt_level,
|
||||||
@ -39,7 +40,8 @@ class CompilerFunctor {
|
|||||||
LLVMCompiler::ModuleHook post_optimization_hook = nullptr,
|
LLVMCompiler::ModuleHook post_optimization_hook = nullptr,
|
||||||
std::function<void(const llvm::object::ObjectFile&)> post_codegen_hook =
|
std::function<void(const llvm::object::ObjectFile&)> post_codegen_hook =
|
||||||
nullptr)
|
nullptr)
|
||||||
: target_machine_(target_machine),
|
: IRCompiler(llvm::orc::IRSymbolMapper::ManglingOptions()),
|
||||||
|
target_machine_(target_machine),
|
||||||
opt_level_(opt_level),
|
opt_level_(opt_level),
|
||||||
optimize_for_size_(optimize_for_size),
|
optimize_for_size_(optimize_for_size),
|
||||||
disable_expensive_passes_(disable_expensive_passes),
|
disable_expensive_passes_(disable_expensive_passes),
|
||||||
@ -49,8 +51,8 @@ class CompilerFunctor {
|
|||||||
post_codegen_hook_(std::move(post_codegen_hook)) {}
|
post_codegen_hook_(std::move(post_codegen_hook)) {}
|
||||||
|
|
||||||
// Compile a Module to an ObjectFile.
|
// Compile a Module to an ObjectFile.
|
||||||
std::unique_ptr<llvm::MemoryBuffer> operator()(
|
llvm::Expected<std::unique_ptr<llvm::MemoryBuffer>> operator()(
|
||||||
llvm::Module& module) const; // NOLINT
|
llvm::Module& module) override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
// Populates the given pass manager with TargetLibraryInfo and
|
// Populates the given pass manager with TargetLibraryInfo and
|
||||||
|
@ -19,6 +19,7 @@ limitations under the License.
|
|||||||
#include <string.h>
|
#include <string.h>
|
||||||
|
|
||||||
#include <map>
|
#include <map>
|
||||||
|
#include <memory>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <unordered_map>
|
#include <unordered_map>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
@ -38,6 +39,7 @@ limitations under the License.
|
|||||||
#include "llvm/IR/Verifier.h"
|
#include "llvm/IR/Verifier.h"
|
||||||
#include "llvm/Object/ObjectFile.h"
|
#include "llvm/Object/ObjectFile.h"
|
||||||
#include "llvm/Support/CommandLine.h"
|
#include "llvm/Support/CommandLine.h"
|
||||||
|
#include "llvm/Support/Error.h"
|
||||||
#include "llvm/Support/TargetRegistry.h"
|
#include "llvm/Support/TargetRegistry.h"
|
||||||
#include "llvm/Support/TargetSelect.h"
|
#include "llvm/Support/TargetSelect.h"
|
||||||
#include "llvm/Target/TargetMachine.h"
|
#include "llvm/Target/TargetMachine.h"
|
||||||
@ -645,11 +647,11 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::RunBackend(
|
|||||||
// Compile must be thread-safe so create a new LLVM context for the module.
|
// Compile must be thread-safe so create a new LLVM context for the module.
|
||||||
mlir::MLIRContext mlir_context;
|
mlir::MLIRContext mlir_context;
|
||||||
LoadMLIRDialects(mlir_context);
|
LoadMLIRDialects(mlir_context);
|
||||||
llvm::LLVMContext llvm_context;
|
auto llvm_context = std::make_unique<llvm::LLVMContext>();
|
||||||
auto llvm_module =
|
auto llvm_module =
|
||||||
absl::make_unique<llvm::Module>("__compute_module", llvm_context);
|
absl::make_unique<llvm::Module>("__compute_module", *llvm_context);
|
||||||
|
|
||||||
auto jit = absl::make_unique<SimpleOrcJIT>(
|
auto jit = SimpleOrcJIT::Create(
|
||||||
CompilerTargetOptions(module->config()),
|
CompilerTargetOptions(module->config()),
|
||||||
CodeGenOptLevel(module->config()),
|
CodeGenOptLevel(module->config()),
|
||||||
options::OptimizeForSizeRequested(module->config()),
|
options::OptimizeForSizeRequested(module->config()),
|
||||||
@ -657,8 +659,12 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::RunBackend(
|
|||||||
llvm_ir::GetCpuFastMathFlags(module->config()), pre_optimization_ir_hook,
|
llvm_ir::GetCpuFastMathFlags(module->config()), pre_optimization_ir_hook,
|
||||||
post_optimization_ir_hook,
|
post_optimization_ir_hook,
|
||||||
OrcJITPostCompilationHook::Create(module.get()));
|
OrcJITPostCompilationHook::Create(module.get()));
|
||||||
llvm_module->setDataLayout(jit->data_layout());
|
if (!jit) {
|
||||||
llvm_module->setTargetTriple(jit->target_triple().getTriple());
|
return InternalError("Creating JIT failed: %s",
|
||||||
|
llvm::toString(jit.takeError()));
|
||||||
|
}
|
||||||
|
llvm_module->setDataLayout((*jit)->data_layout());
|
||||||
|
llvm_module->setTargetTriple((*jit)->target_triple().getTriple());
|
||||||
|
|
||||||
HloComputation* entry_computation = module->entry_computation();
|
HloComputation* entry_computation = module->entry_computation();
|
||||||
std::unordered_map<const HloInstruction*, int64> instruction_to_profile_idx;
|
std::unordered_map<const HloInstruction*, int64> instruction_to_profile_idx;
|
||||||
@ -700,7 +706,7 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::RunBackend(
|
|||||||
// GetEmbeddedComputations guarantees that a called computation occurs
|
// GetEmbeddedComputations guarantees that a called computation occurs
|
||||||
// before a caller computation.
|
// before a caller computation.
|
||||||
|
|
||||||
LLVMTargetMachineFeatures target_machine_features(jit->target_machine());
|
LLVMTargetMachineFeatures target_machine_features((*jit)->target_machine());
|
||||||
IrEmitter ir_emitter(&mlir_context, *module, *assignment, llvm_module.get(),
|
IrEmitter ir_emitter(&mlir_context, *module, *assignment, llvm_module.get(),
|
||||||
std::move(instruction_to_profile_idx),
|
std::move(instruction_to_profile_idx),
|
||||||
std::move(computation_to_profile_idx),
|
std::move(computation_to_profile_idx),
|
||||||
@ -739,7 +745,7 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::RunBackend(
|
|||||||
string function_name = [&]() {
|
string function_name = [&]() {
|
||||||
llvm::SmallVector<char, 40> function_name_vector;
|
llvm::SmallVector<char, 40> function_name_vector;
|
||||||
llvm::Mangler::getNameWithPrefix(
|
llvm::Mangler::getNameWithPrefix(
|
||||||
function_name_vector, entry_function->getName(), jit->data_layout());
|
function_name_vector, entry_function->getName(), (*jit)->data_layout());
|
||||||
return string(function_name_vector.begin(), function_name_vector.end());
|
return string(function_name_vector.begin(), function_name_vector.end());
|
||||||
}();
|
}();
|
||||||
|
|
||||||
@ -751,9 +757,11 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::RunBackend(
|
|||||||
TF_RETURN_IF_ERROR(VerifyLlvmModule(*llvm_module));
|
TF_RETURN_IF_ERROR(VerifyLlvmModule(*llvm_module));
|
||||||
|
|
||||||
// JIT compile the LLVM IR module to in-memory machine code.
|
// JIT compile the LLVM IR module to in-memory machine code.
|
||||||
jit->AddModule(std::move(llvm_module));
|
llvm::orc::ThreadSafeModule thread_safe_module(std::move(llvm_module),
|
||||||
|
std::move(llvm_context));
|
||||||
|
cantFail((*jit)->AddModule(std::move(thread_safe_module)));
|
||||||
cpu_executable.reset(new CpuExecutable(
|
cpu_executable.reset(new CpuExecutable(
|
||||||
std::move(jit), std::move(assignment), std::move(module), function_name,
|
std::move(*jit), std::move(assignment), std::move(module), function_name,
|
||||||
std::move(hlo_profile_printer_data), std::move(hlo_profile_index_map)));
|
std::move(hlo_profile_printer_data), std::move(hlo_profile_index_map)));
|
||||||
|
|
||||||
if (embed_ir_in_executable) {
|
if (embed_ir_in_executable) {
|
||||||
@ -971,7 +979,7 @@ CpuCompiler::CompileAheadOfTime(std::unique_ptr<HloModuleGroup> module_group,
|
|||||||
llvm_ir::GetCpuFastMathFlags(module->config()),
|
llvm_ir::GetCpuFastMathFlags(module->config()),
|
||||||
pre_optimization_ir_hook, post_optimization_ir_hook, post_codegen_hook);
|
pre_optimization_ir_hook, post_optimization_ir_hook, post_codegen_hook);
|
||||||
std::unique_ptr<llvm::MemoryBuffer> object_file =
|
std::unique_ptr<llvm::MemoryBuffer> object_file =
|
||||||
compiler_functor(llvm_module);
|
cantFail(compiler_functor(llvm_module));
|
||||||
ObjectFileData object_file_data(object_file->getBufferStart(),
|
ObjectFileData object_file_data(object_file->getBufferStart(),
|
||||||
object_file->getBufferEnd());
|
object_file->getBufferEnd());
|
||||||
|
|
||||||
|
@ -63,14 +63,14 @@ CpuExecutable::CpuExecutable(
|
|||||||
assignment_(std::move(assignment)) {
|
assignment_(std::move(assignment)) {
|
||||||
// Resolve symbols in the constructor rather than at execution time to avoid
|
// Resolve symbols in the constructor rather than at execution time to avoid
|
||||||
// races because FindSymbol is not thread safe.
|
// races because FindSymbol is not thread safe.
|
||||||
llvm::JITSymbol sym = jit_->FindCompiledSymbol(entry_function_name);
|
llvm::Expected<llvm::JITEvaluatedSymbol> sym =
|
||||||
|
jit_->FindCompiledSymbol(entry_function_name);
|
||||||
// We expect to find the symbol provided with entry_function_name; otherwise
|
// We expect to find the symbol provided with entry_function_name; otherwise
|
||||||
// this is an internal error.
|
// this is an internal error.
|
||||||
CHECK(sym) << "Symbol " << entry_function_name << " not found.";
|
CHECK(sym) << "Symbol " << entry_function_name << " not found.";
|
||||||
// getAddress can do work under the hood in the jit, so it needs to be
|
// getAddress can do work under the hood in the jit, so it needs to be
|
||||||
// guarded by the mutex.
|
// guarded by the mutex.
|
||||||
compute_function_ =
|
compute_function_ = reinterpret_cast<ComputeFunctionType>(sym->getAddress());
|
||||||
reinterpret_cast<ComputeFunctionType>(cantFail(sym.getAddress()));
|
|
||||||
VLOG(1) << "compute_function_ at address "
|
VLOG(1) << "compute_function_ at address "
|
||||||
<< reinterpret_cast<void*>(compute_function_);
|
<< reinterpret_cast<void*>(compute_function_);
|
||||||
}
|
}
|
||||||
|
@ -85,6 +85,8 @@ SimpleOrcJIT::InferTargetMachineForJIT(
|
|||||||
}
|
}
|
||||||
|
|
||||||
SimpleOrcJIT::SimpleOrcJIT(
|
SimpleOrcJIT::SimpleOrcJIT(
|
||||||
|
std::unique_ptr<llvm::orc::TargetProcessControl> target_process_control,
|
||||||
|
std::unique_ptr<llvm::orc::ExecutionSession> execution_session,
|
||||||
const llvm::TargetOptions& target_options,
|
const llvm::TargetOptions& target_options,
|
||||||
llvm::CodeGenOpt::Level opt_level, bool optimize_for_size,
|
llvm::CodeGenOpt::Level opt_level, bool optimize_for_size,
|
||||||
bool disable_expensive_passes, llvm::FastMathFlags fast_math_flags,
|
bool disable_expensive_passes, llvm::FastMathFlags fast_math_flags,
|
||||||
@ -93,48 +95,78 @@ SimpleOrcJIT::SimpleOrcJIT(
|
|||||||
std::function<void(const llvm::object::ObjectFile&)> post_codegen_hook)
|
std::function<void(const llvm::object::ObjectFile&)> post_codegen_hook)
|
||||||
: target_machine_(InferTargetMachineForJIT(target_options, opt_level)),
|
: target_machine_(InferTargetMachineForJIT(target_options, opt_level)),
|
||||||
data_layout_(target_machine_->createDataLayout()),
|
data_layout_(target_machine_->createDataLayout()),
|
||||||
symbol_resolver_(llvm::orc::createLegacyLookupResolver(
|
target_process_control_(std::move(target_process_control)),
|
||||||
execution_session_,
|
execution_session_(std::move(execution_session)),
|
||||||
[this](llvm::StringRef name) -> llvm::JITSymbol {
|
object_layer_(*execution_session_,
|
||||||
return this->ResolveRuntimeSymbol(std::string(name));
|
[]() {
|
||||||
},
|
return std::make_unique<llvm::SectionMemoryManager>(
|
||||||
[](llvm::Error Err) {
|
|
||||||
cantFail(std::move(Err), "lookupFlags failed");
|
|
||||||
})),
|
|
||||||
object_layer_(
|
|
||||||
execution_session_,
|
|
||||||
[this](llvm::orc::VModuleKey) {
|
|
||||||
llvm::orc::LegacyRTDyldObjectLinkingLayer::Resources result;
|
|
||||||
result.MemMgr = std::make_shared<llvm::SectionMemoryManager>(
|
|
||||||
orc_jit_memory_mapper::GetInstance());
|
orc_jit_memory_mapper::GetInstance());
|
||||||
result.Resolver = symbol_resolver_;
|
|
||||||
return result;
|
|
||||||
},
|
|
||||||
/*NotifyLoaded=*/
|
|
||||||
llvm::orc::LegacyRTDyldObjectLinkingLayer::NotifyLoadedFtor(),
|
|
||||||
/*NotifyFinalized=*/
|
|
||||||
[this](VModuleKeyT, const llvm::object::ObjectFile& object,
|
|
||||||
const llvm::RuntimeDyld::LoadedObjectInfo& object_info) {
|
|
||||||
this->NotifyObjectFinalized(object, object_info);
|
|
||||||
},
|
|
||||||
/*NotifyFreed=*/
|
|
||||||
[this](VModuleKeyT, const llvm::object::ObjectFile& object) {
|
|
||||||
this->NotifyObjectFreed(object);
|
|
||||||
}),
|
}),
|
||||||
compile_layer_(
|
compile_layer_(
|
||||||
object_layer_,
|
*execution_session_, object_layer_,
|
||||||
CompilerFunctor(target_machine_.get(), opt_level, optimize_for_size,
|
std::make_unique<CompilerFunctor>(
|
||||||
|
target_machine_.get(), opt_level, optimize_for_size,
|
||||||
disable_expensive_passes, fast_math_flags,
|
disable_expensive_passes, fast_math_flags,
|
||||||
std::move(pre_optimization_hook),
|
std::move(pre_optimization_hook),
|
||||||
std::move(post_optimization_hook),
|
std::move(post_optimization_hook), std::move(post_codegen_hook))),
|
||||||
std::move(post_codegen_hook))),
|
main_jit_dylib_(&execution_session_->createBareJITDylib("<main>")),
|
||||||
gdb_jit_event_listener_(
|
gdb_jit_event_listener_(
|
||||||
llvm::JITEventListener::createGDBRegistrationListener()) {
|
llvm::JITEventListener::createGDBRegistrationListener()) {
|
||||||
VLOG(1) << "CPU target: " << target_machine_->getTargetCPU().str()
|
VLOG(1) << "CPU target: " << target_machine_->getTargetCPU().str()
|
||||||
<< " features: " << target_machine_->getTargetFeatureString().str();
|
<< " features: " << target_machine_->getTargetFeatureString().str();
|
||||||
|
|
||||||
|
// Materialize unknown symbols from the runtime symbol table.
|
||||||
|
class RuntimeSymbolGenerator
|
||||||
|
: public llvm::orc::JITDylib::DefinitionGenerator {
|
||||||
|
SimpleOrcJIT& jit_;
|
||||||
|
|
||||||
|
public:
|
||||||
|
explicit RuntimeSymbolGenerator(SimpleOrcJIT& jit) : jit_(jit) {}
|
||||||
|
llvm::Error tryToGenerate(
|
||||||
|
llvm::orc::LookupKind, llvm::orc::JITDylib& jit_dylib,
|
||||||
|
llvm::orc::JITDylibLookupFlags,
|
||||||
|
const llvm::orc::SymbolLookupSet& names) override {
|
||||||
|
llvm::orc::SymbolMap new_defs;
|
||||||
|
|
||||||
|
for (const auto& kv : names) {
|
||||||
|
const auto& name = kv.first;
|
||||||
|
if (llvm::JITEvaluatedSymbol symbol =
|
||||||
|
jit_.ResolveRuntimeSymbol(*name)) {
|
||||||
|
new_defs[name] = symbol;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
cantFail(jit_dylib.define(absoluteSymbols(std::move(new_defs))));
|
||||||
|
return llvm::Error::success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
main_jit_dylib_->addGenerator(
|
||||||
|
std::make_unique<RuntimeSymbolGenerator>(*this));
|
||||||
|
object_layer_.registerJITEventListener(*this);
|
||||||
}
|
}
|
||||||
|
|
||||||
llvm::JITSymbol SimpleOrcJIT::ResolveRuntimeSymbol(const std::string& name) {
|
llvm::Expected<std::unique_ptr<SimpleOrcJIT>> SimpleOrcJIT::Create(
|
||||||
|
const llvm::TargetOptions& target_options,
|
||||||
|
llvm::CodeGenOpt::Level opt_level, bool optimize_for_size,
|
||||||
|
bool disable_expensive_passes, llvm::FastMathFlags fast_math_flags,
|
||||||
|
LLVMCompiler::ModuleHook pre_optimization_hook,
|
||||||
|
LLVMCompiler::ModuleHook post_optimization_hook,
|
||||||
|
std::function<void(const llvm::object::ObjectFile&)> post_codegen_hook) {
|
||||||
|
auto target_process_control = llvm::orc::SelfTargetProcessControl::Create();
|
||||||
|
if (!target_process_control) {
|
||||||
|
return target_process_control.takeError();
|
||||||
|
}
|
||||||
|
|
||||||
|
auto execution_session = std::make_unique<llvm::orc::ExecutionSession>();
|
||||||
|
return std::make_unique<SimpleOrcJIT>(
|
||||||
|
std::move(*target_process_control), std::move(execution_session),
|
||||||
|
target_options, opt_level, optimize_for_size, disable_expensive_passes,
|
||||||
|
fast_math_flags, std::move(pre_optimization_hook),
|
||||||
|
std::move(post_optimization_hook), std::move(post_codegen_hook));
|
||||||
|
}
|
||||||
|
|
||||||
|
llvm::JITEvaluatedSymbol SimpleOrcJIT::ResolveRuntimeSymbol(
|
||||||
|
llvm::StringRef name) {
|
||||||
void* func_addr = nullptr;
|
void* func_addr = nullptr;
|
||||||
if (name.size() > 1 && name.front() == data_layout_.getGlobalPrefix()) {
|
if (name.size() > 1 && name.front() == data_layout_.getGlobalPrefix()) {
|
||||||
// On Mac OS X, 'name' may have a leading underscore prefix, even though the
|
// On Mac OS X, 'name' may have a leading underscore prefix, even though the
|
||||||
@ -143,12 +175,13 @@ llvm::JITSymbol SimpleOrcJIT::ResolveRuntimeSymbol(const std::string& name) {
|
|||||||
func_addr =
|
func_addr =
|
||||||
xla::CustomCallTargetRegistry::Global()->Lookup(stripped_name, "Host");
|
xla::CustomCallTargetRegistry::Global()->Lookup(stripped_name, "Host");
|
||||||
} else {
|
} else {
|
||||||
func_addr = xla::CustomCallTargetRegistry::Global()->Lookup(name, "Host");
|
func_addr =
|
||||||
|
xla::CustomCallTargetRegistry::Global()->Lookup(name.str(), "Host");
|
||||||
}
|
}
|
||||||
|
|
||||||
if (func_addr == nullptr) {
|
if (func_addr == nullptr) {
|
||||||
LOG(ERROR)
|
LOG(ERROR)
|
||||||
<< "Unable to resolve runtime symbol: `" << name
|
<< "Unable to resolve runtime symbol: `" << name.str()
|
||||||
<< "'. Hint: if the symbol a custom call target, make sure you've "
|
<< "'. Hint: if the symbol a custom call target, make sure you've "
|
||||||
"registered it with the JIT using "
|
"registered it with the JIT using "
|
||||||
"XLA_CPU_REGISTER_CUSTOM_CALL_TARGET.";
|
"XLA_CPU_REGISTER_CUSTOM_CALL_TARGET.";
|
||||||
@ -159,60 +192,25 @@ llvm::JITSymbol SimpleOrcJIT::ResolveRuntimeSymbol(const std::string& name) {
|
|||||||
return symbol_info;
|
return symbol_info;
|
||||||
}
|
}
|
||||||
|
|
||||||
void SimpleOrcJIT::NotifyObjectFinalized(
|
void SimpleOrcJIT::notifyObjectLoaded(
|
||||||
|
llvm::JITEventListener::ObjectKey key,
|
||||||
const llvm::object::ObjectFile& object,
|
const llvm::object::ObjectFile& object,
|
||||||
const llvm::RuntimeDyld::LoadedObjectInfo& object_info) {
|
const llvm::RuntimeDyld::LoadedObjectInfo& object_info) {
|
||||||
uint64_t key = static_cast<uint64_t>(
|
|
||||||
reinterpret_cast<uintptr_t>(object.getData().data()));
|
|
||||||
gdb_jit_event_listener_->notifyObjectLoaded(key, object, object_info);
|
gdb_jit_event_listener_->notifyObjectLoaded(key, object, object_info);
|
||||||
size_of_generated_code_in_bytes_ += object.getData().size();
|
size_of_generated_code_in_bytes_ += object.getData().size();
|
||||||
}
|
}
|
||||||
|
|
||||||
void SimpleOrcJIT::NotifyObjectFreed(const llvm::object::ObjectFile& object) {
|
void SimpleOrcJIT::notifyFreeingObject(llvm::JITEventListener::ObjectKey key) {
|
||||||
uint64_t key = static_cast<uint64_t>(
|
|
||||||
reinterpret_cast<uintptr_t>(object.getData().data()));
|
|
||||||
gdb_jit_event_listener_->notifyFreeingObject(key);
|
gdb_jit_event_listener_->notifyFreeingObject(key);
|
||||||
}
|
}
|
||||||
|
|
||||||
SimpleOrcJIT::VModuleKeyT SimpleOrcJIT::AddModule(
|
llvm::Error SimpleOrcJIT::AddModule(llvm::orc::ThreadSafeModule module) {
|
||||||
std::unique_ptr<llvm::Module> module) {
|
return compile_layer_.add(*main_jit_dylib_, std::move(module));
|
||||||
auto key = execution_session_.allocateVModule();
|
|
||||||
cantFail(compile_layer_.addModule(key, std::move(module)));
|
|
||||||
module_keys_.push_back(key);
|
|
||||||
return key;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void SimpleOrcJIT::RemoveModule(SimpleOrcJIT::VModuleKeyT key) {
|
llvm::Expected<llvm::JITEvaluatedSymbol> SimpleOrcJIT::FindCompiledSymbol(
|
||||||
module_keys_.erase(std::remove(module_keys_.begin(), module_keys_.end(), key),
|
const std::string& name) {
|
||||||
module_keys_.end());
|
return execution_session_->lookup({main_jit_dylib_}, name);
|
||||||
cantFail(compile_layer_.removeModule(key));
|
|
||||||
}
|
|
||||||
|
|
||||||
llvm::JITSymbol SimpleOrcJIT::FindCompiledSymbol(const std::string& name) {
|
|
||||||
#ifdef _WIN32
|
|
||||||
// The symbol lookup of ObjectLinkingLayer uses the SymbolRef::SF_Exported
|
|
||||||
// flag to decide whether a symbol will be visible or not, when we call
|
|
||||||
// IRCompileLayer::findSymbolIn with ExportedSymbolsOnly set to true.
|
|
||||||
//
|
|
||||||
// But for Windows COFF objects, this flag is currently never set.
|
|
||||||
// For a potential solution see: https://reviews.llvm.org/rL258665
|
|
||||||
// For now, we allow non-exported symbols on Windows as a workaround.
|
|
||||||
const bool exported_symbols_only = false;
|
|
||||||
#else
|
|
||||||
const bool exported_symbols_only = true;
|
|
||||||
#endif
|
|
||||||
|
|
||||||
// Resolve symbol from last module to first, allowing later redefinitions of
|
|
||||||
// symbols shadow earlier ones.
|
|
||||||
for (auto& key :
|
|
||||||
llvm::make_range(module_keys_.rbegin(), module_keys_.rend())) {
|
|
||||||
if (auto symbol =
|
|
||||||
compile_layer_.findSymbolIn(key, name, exported_symbols_only)) {
|
|
||||||
return symbol;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nullptr;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#if defined(PLATFORM_WINDOWS)
|
#if defined(PLATFORM_WINDOWS)
|
||||||
|
@ -26,6 +26,7 @@ limitations under the License.
|
|||||||
#include "llvm/ExecutionEngine/Orc/IRCompileLayer.h"
|
#include "llvm/ExecutionEngine/Orc/IRCompileLayer.h"
|
||||||
#include "llvm/ExecutionEngine/Orc/RTDyldObjectLinkingLayer.h"
|
#include "llvm/ExecutionEngine/Orc/RTDyldObjectLinkingLayer.h"
|
||||||
#include "llvm/ExecutionEngine/Orc/SymbolStringPool.h"
|
#include "llvm/ExecutionEngine/Orc/SymbolStringPool.h"
|
||||||
|
#include "llvm/ExecutionEngine/Orc/TargetProcessControl.h"
|
||||||
#include "llvm/IR/Module.h"
|
#include "llvm/IR/Module.h"
|
||||||
#include "llvm/Target/TargetMachine.h"
|
#include "llvm/Target/TargetMachine.h"
|
||||||
#include "tensorflow/compiler/xla/service/cpu/compiler_functor.h"
|
#include "tensorflow/compiler/xla/service/cpu/compiler_functor.h"
|
||||||
@ -42,13 +43,10 @@ namespace cpu {
|
|||||||
// Supports JIT-ing multiple modules but without cross-module linking.
|
// Supports JIT-ing multiple modules but without cross-module linking.
|
||||||
// Implements eager compilation - the module is lowered to binary as soon as
|
// Implements eager compilation - the module is lowered to binary as soon as
|
||||||
// it's added to the JIT.
|
// it's added to the JIT.
|
||||||
class SimpleOrcJIT {
|
class SimpleOrcJIT : public llvm::JITEventListener {
|
||||||
public:
|
public:
|
||||||
using ObjLayerT = llvm::orc::LegacyRTDyldObjectLinkingLayer;
|
using ObjLayerT = llvm::orc::RTDyldObjectLinkingLayer;
|
||||||
using CompileFtor =
|
using CompileLayerT = llvm::orc::IRCompileLayer;
|
||||||
std::function<llvm::Expected<ObjLayerT::ObjectPtr>(llvm::Module&)>;
|
|
||||||
using CompileLayerT = llvm::orc::LegacyIRCompileLayer<ObjLayerT, CompileFtor>;
|
|
||||||
using VModuleKeyT = llvm::orc::VModuleKey;
|
|
||||||
|
|
||||||
// Create a new JIT, targeting the host architecture.
|
// Create a new JIT, targeting the host architecture.
|
||||||
//
|
//
|
||||||
@ -56,6 +54,16 @@ class SimpleOrcJIT {
|
|||||||
// LLVM IR-level optimizations. post_codegen_hook is invoked after
|
// LLVM IR-level optimizations. post_codegen_hook is invoked after
|
||||||
// compiling to machine code.
|
// compiling to machine code.
|
||||||
SimpleOrcJIT(
|
SimpleOrcJIT(
|
||||||
|
std::unique_ptr<llvm::orc::TargetProcessControl> target_process_control,
|
||||||
|
std::unique_ptr<llvm::orc::ExecutionSession> execution_session,
|
||||||
|
const llvm::TargetOptions& target_options,
|
||||||
|
llvm::CodeGenOpt::Level opt_level, bool optimize_for_size,
|
||||||
|
bool disable_expensive_passes, llvm::FastMathFlags fast_math_flags,
|
||||||
|
LLVMCompiler::ModuleHook pre_optimization_hook,
|
||||||
|
LLVMCompiler::ModuleHook post_optimization_hook,
|
||||||
|
std::function<void(const llvm::object::ObjectFile&)> post_codegen_hook);
|
||||||
|
|
||||||
|
static llvm::Expected<std::unique_ptr<SimpleOrcJIT>> Create(
|
||||||
const llvm::TargetOptions& target_options,
|
const llvm::TargetOptions& target_options,
|
||||||
llvm::CodeGenOpt::Level opt_level, bool optimize_for_size,
|
llvm::CodeGenOpt::Level opt_level, bool optimize_for_size,
|
||||||
bool disable_expensive_passes, llvm::FastMathFlags fast_math_flags,
|
bool disable_expensive_passes, llvm::FastMathFlags fast_math_flags,
|
||||||
@ -69,16 +77,12 @@ class SimpleOrcJIT {
|
|||||||
return target_machine_->getTargetTriple();
|
return target_machine_->getTargetTriple();
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add a module to the JIT. Returns an opaque key that can be used to later
|
llvm::Error AddModule(llvm::orc::ThreadSafeModule module);
|
||||||
// remove this module.
|
|
||||||
VModuleKeyT AddModule(std::unique_ptr<llvm::Module> module);
|
|
||||||
|
|
||||||
// Remove a module from the JIT and free the memory associated with it.
|
|
||||||
void RemoveModule(VModuleKeyT key);
|
|
||||||
|
|
||||||
// Get the runtime address of the compiled symbol whose name is given. Returns
|
// Get the runtime address of the compiled symbol whose name is given. Returns
|
||||||
// nullptr if the symbol cannot be found.
|
// nullptr if the symbol cannot be found.
|
||||||
llvm::JITSymbol FindCompiledSymbol(const std::string& name);
|
llvm::Expected<llvm::JITEvaluatedSymbol> FindCompiledSymbol(
|
||||||
|
const std::string& name);
|
||||||
|
|
||||||
llvm::TargetMachine* target_machine() const { return target_machine_.get(); }
|
llvm::TargetMachine* target_machine() const { return target_machine_.get(); }
|
||||||
|
|
||||||
@ -93,20 +97,21 @@ class SimpleOrcJIT {
|
|||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
llvm::JITSymbol ResolveRuntimeSymbol(const std::string& name);
|
llvm::JITEvaluatedSymbol ResolveRuntimeSymbol(llvm::StringRef name);
|
||||||
|
|
||||||
void NotifyObjectFinalized(
|
void notifyObjectLoaded(
|
||||||
|
llvm::JITEventListener::ObjectKey key,
|
||||||
const llvm::object::ObjectFile& object,
|
const llvm::object::ObjectFile& object,
|
||||||
const llvm::RuntimeDyld::LoadedObjectInfo& object_info);
|
const llvm::RuntimeDyld::LoadedObjectInfo& object_info) override;
|
||||||
void NotifyObjectFreed(const llvm::object::ObjectFile& object);
|
void notifyFreeingObject(llvm::JITEventListener::ObjectKey key) override;
|
||||||
|
|
||||||
std::vector<VModuleKeyT> module_keys_;
|
|
||||||
std::unique_ptr<llvm::TargetMachine> target_machine_;
|
std::unique_ptr<llvm::TargetMachine> target_machine_;
|
||||||
const llvm::DataLayout data_layout_;
|
const llvm::DataLayout data_layout_;
|
||||||
llvm::orc::ExecutionSession execution_session_;
|
std::unique_ptr<llvm::orc::TargetProcessControl> target_process_control_;
|
||||||
std::shared_ptr<llvm::orc::SymbolResolver> symbol_resolver_;
|
std::unique_ptr<llvm::orc::ExecutionSession> execution_session_;
|
||||||
ObjLayerT object_layer_;
|
ObjLayerT object_layer_;
|
||||||
CompileLayerT compile_layer_;
|
CompileLayerT compile_layer_;
|
||||||
|
llvm::orc::JITDylib* main_jit_dylib_;
|
||||||
int64 size_of_generated_code_in_bytes_ = 0;
|
int64 size_of_generated_code_in_bytes_ = 0;
|
||||||
|
|
||||||
// Non owning pointer to a JIT event listener that registers the JIT events
|
// Non owning pointer to a JIT event listener that registers the JIT events
|
||||||
|
Loading…
x
Reference in New Issue
Block a user