[XLA:CPU] Port XLA:CPU JIT to OrcV2

OrcV1 is going away upstream (LLVM commit 6154c4115c).

PiperOrigin-RevId: 338071271
Change-Id: I429a09c39f62c76974482092a3ba14163bf74802
Benjamin Kramer authored 2020-10-20 09:09:56 -07:00, committed by TensorFlower Gardener
parent 5fe5a49092
commit 4f7ef4ecb4
7 changed files with 134 additions and 120 deletions

tensorflow/compiler/xla/service/cpu/BUILD

@@ -552,6 +552,7 @@ cc_library(
"@llvm-project//llvm:IPO",
"@llvm-project//llvm:MC",
"@llvm-project//llvm:Object",
"@llvm-project//llvm:OrcJIT",
"@llvm-project//llvm:Support",
"@llvm-project//llvm:Target",
],

tensorflow/compiler/xla/service/cpu/compiler_functor.cc

@@ -80,8 +80,8 @@ class FilteredPassManager : public llvm::legacy::PassManager {
};
} // anonymous namespace
std::unique_ptr<llvm::MemoryBuffer> CompilerFunctor::operator()(
llvm::Module& module) const {
llvm::Expected<std::unique_ptr<llvm::MemoryBuffer>> CompilerFunctor::operator()(
llvm::Module& module) {
FilteredPassManager module_passes(disable_expensive_passes_);
llvm::legacy::FunctionPassManager function_passes(&module);
@@ -155,7 +155,7 @@ std::unique_ptr<llvm::MemoryBuffer> CompilerFunctor::operator()(
}
}
return memory_buffer;
return std::move(memory_buffer);
}
static std::vector<llvm::VecDesc> VectorFunctionsForTargetLibraryInfoImpl() {
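
A note on the two changes in this file: the OrcV2 IRCompiler interface returns llvm::Expected instead of a bare buffer, and the explicit std::move is needed because the pre-C++20 implicit-move-on-return rules do not kick in when the returned local's type differs from the new return type. Below is a minimal sketch of that pattern, assuming nothing beyond llvm/Support; MakeObjectBuffer and its strings are illustrative, not code from this change.

    #include <memory>
    #include <utility>

    #include "llvm/Support/Error.h"
    #include "llvm/Support/MemoryBuffer.h"

    // Sketch: returning a std::unique_ptr through llvm::Expected. With the
    // toolchains of this era a plain `return buffer;` attempts a copy of the
    // unique_ptr and fails to compile; newer compilers apply the implicit move
    // and may instead warn that the std::move is redundant.
    llvm::Expected<std::unique_ptr<llvm::MemoryBuffer>> MakeObjectBuffer(
        bool simulate_failure) {
      if (simulate_failure) {
        return llvm::make_error<llvm::StringError>(
            "codegen failed", llvm::inconvertibleErrorCode());
      }
      std::unique_ptr<llvm::MemoryBuffer> buffer =
          llvm::MemoryBuffer::getMemBufferCopy("<object bytes>");
      return std::move(buffer);
    }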

tensorflow/compiler/xla/service/cpu/compiler_functor.h

@@ -16,6 +16,7 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_COMPILER_FUNCTOR_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_COMPILER_FUNCTOR_H_
#include "llvm/ExecutionEngine/Orc/IRCompileLayer.h"
#include "llvm/IR/LegacyPassManager.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Operator.h"
@@ -29,7 +30,7 @@ namespace cpu {
// Functor class for compiling an LLVM module down to an object file. For use by
// Orc JIT compile layer.
class CompilerFunctor {
class CompilerFunctor : public llvm::orc::IRCompileLayer::IRCompiler {
public:
explicit CompilerFunctor(
llvm::TargetMachine* target_machine, int opt_level,
@@ -39,7 +40,8 @@ class CompilerFunctor {
LLVMCompiler::ModuleHook post_optimization_hook = nullptr,
std::function<void(const llvm::object::ObjectFile&)> post_codegen_hook =
nullptr)
: target_machine_(target_machine),
: IRCompiler(llvm::orc::IRSymbolMapper::ManglingOptions()),
target_machine_(target_machine),
opt_level_(opt_level),
optimize_for_size_(optimize_for_size),
disable_expensive_passes_(disable_expensive_passes),
@@ -49,8 +51,8 @@ class CompilerFunctor {
post_codegen_hook_(std::move(post_codegen_hook)) {}
// Compile a Module to an ObjectFile.
std::unique_ptr<llvm::MemoryBuffer> operator()(
llvm::Module& module) const; // NOLINT
llvm::Expected<std::unique_ptr<llvm::MemoryBuffer>> operator()(
llvm::Module& module) override;
private:
// Populates the given pass manager with TargetLibraryInfo and
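
For orientation, this is the OrcV2 contract CompilerFunctor now implements: the IRCompileLayer owns an IRCompiler and invokes its operator() to turn a module into an object buffer. A stripped-down sketch follows, assuming the LLVM headers in use at the time of this change; MinimalCompiler is an illustrative name, and the optimization pipeline, hooks, and vectorized-function setup of the real functor are omitted.

    #include <memory>
    #include <utility>

    #include "llvm/ADT/SmallVector.h"
    #include "llvm/ExecutionEngine/Orc/IRCompileLayer.h"
    #include "llvm/IR/LegacyPassManager.h"
    #include "llvm/IR/Module.h"
    #include "llvm/Support/Error.h"
    #include "llvm/Support/MemoryBuffer.h"
    #include "llvm/Support/SmallVectorMemoryBuffer.h"
    #include "llvm/Support/raw_ostream.h"
    #include "llvm/Target/TargetMachine.h"

    // Sketch: the smallest useful IRCompiler. It runs the TargetMachine's
    // codegen passes over the module and hands the resulting object file back
    // to the IRCompileLayer.
    class MinimalCompiler : public llvm::orc::IRCompileLayer::IRCompiler {
     public:
      explicit MinimalCompiler(llvm::TargetMachine& target_machine)
          : IRCompiler(llvm::orc::IRSymbolMapper::ManglingOptions()),
            target_machine_(target_machine) {}

      llvm::Expected<std::unique_ptr<llvm::MemoryBuffer>> operator()(
          llvm::Module& module) override {
        llvm::SmallVector<char, 0> object_bytes;
        llvm::raw_svector_ostream object_stream(object_bytes);
        llvm::legacy::PassManager codegen_passes;
        if (target_machine_.addPassesToEmitFile(codegen_passes, object_stream,
                                                /*DwoOut=*/nullptr,
                                                llvm::CGFT_ObjectFile)) {
          return llvm::make_error<llvm::StringError>(
              "target does not support object file emission",
              llvm::inconvertibleErrorCode());
        }
        codegen_passes.run(module);
        return std::make_unique<llvm::SmallVectorMemoryBuffer>(
            std::move(object_bytes));
      }

     private:
      llvm::TargetMachine& target_machine_;
    };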

tensorflow/compiler/xla/service/cpu/cpu_compiler.cc

@@ -19,6 +19,7 @@ limitations under the License.
#include <string.h>
#include <map>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
@@ -38,6 +39,7 @@ limitations under the License.
#include "llvm/IR/Verifier.h"
#include "llvm/Object/ObjectFile.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Error.h"
#include "llvm/Support/TargetRegistry.h"
#include "llvm/Support/TargetSelect.h"
#include "llvm/Target/TargetMachine.h"
@@ -645,11 +647,11 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::RunBackend(
// Compile must be thread-safe so create a new LLVM context for the module.
mlir::MLIRContext mlir_context;
LoadMLIRDialects(mlir_context);
llvm::LLVMContext llvm_context;
auto llvm_context = std::make_unique<llvm::LLVMContext>();
auto llvm_module =
absl::make_unique<llvm::Module>("__compute_module", llvm_context);
absl::make_unique<llvm::Module>("__compute_module", *llvm_context);
auto jit = absl::make_unique<SimpleOrcJIT>(
auto jit = SimpleOrcJIT::Create(
CompilerTargetOptions(module->config()),
CodeGenOptLevel(module->config()),
options::OptimizeForSizeRequested(module->config()),
@@ -657,8 +659,12 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::RunBackend(
llvm_ir::GetCpuFastMathFlags(module->config()), pre_optimization_ir_hook,
post_optimization_ir_hook,
OrcJITPostCompilationHook::Create(module.get()));
llvm_module->setDataLayout(jit->data_layout());
llvm_module->setTargetTriple(jit->target_triple().getTriple());
if (!jit) {
return InternalError("Creating JIT failed: %s",
llvm::toString(jit.takeError()));
}
llvm_module->setDataLayout((*jit)->data_layout());
llvm_module->setTargetTriple((*jit)->target_triple().getTriple());
HloComputation* entry_computation = module->entry_computation();
std::unordered_map<const HloInstruction*, int64> instruction_to_profile_idx;
@@ -700,7 +706,7 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::RunBackend(
// GetEmbeddedComputations guarantees that a called computation occurs
// before a caller computation.
LLVMTargetMachineFeatures target_machine_features(jit->target_machine());
LLVMTargetMachineFeatures target_machine_features((*jit)->target_machine());
IrEmitter ir_emitter(&mlir_context, *module, *assignment, llvm_module.get(),
std::move(instruction_to_profile_idx),
std::move(computation_to_profile_idx),
@@ -739,7 +745,7 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::RunBackend(
string function_name = [&]() {
llvm::SmallVector<char, 40> function_name_vector;
llvm::Mangler::getNameWithPrefix(
function_name_vector, entry_function->getName(), jit->data_layout());
function_name_vector, entry_function->getName(), (*jit)->data_layout());
return string(function_name_vector.begin(), function_name_vector.end());
}();
@@ -751,9 +757,11 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::RunBackend(
TF_RETURN_IF_ERROR(VerifyLlvmModule(*llvm_module));
// JIT compile the LLVM IR module to in-memory machine code.
jit->AddModule(std::move(llvm_module));
llvm::orc::ThreadSafeModule thread_safe_module(std::move(llvm_module),
std::move(llvm_context));
cantFail((*jit)->AddModule(std::move(thread_safe_module)));
cpu_executable.reset(new CpuExecutable(
std::move(jit), std::move(assignment), std::move(module), function_name,
std::move(*jit), std::move(assignment), std::move(module), function_name,
std::move(hlo_profile_printer_data), std::move(hlo_profile_index_map)));
if (embed_ir_in_executable) {
@@ -971,7 +979,7 @@ CpuCompiler::CompileAheadOfTime(std::unique_ptr<HloModuleGroup> module_group,
llvm_ir::GetCpuFastMathFlags(module->config()),
pre_optimization_ir_hook, post_optimization_ir_hook, post_codegen_hook);
std::unique_ptr<llvm::MemoryBuffer> object_file =
compiler_functor(llvm_module);
cantFail(compiler_functor(llvm_module));
ObjectFileData object_file_data(object_file->getBufferStart(),
object_file->getBufferEnd());
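
The context handling above changes because OrcV2's AddModule takes an llvm::orc::ThreadSafeModule, which owns both the module and the LLVMContext it was built in, so the context can no longer live on RunBackend's stack; Create and AddModule also report failure through llvm::Expected/llvm::Error, hence the takeError and cantFail calls. A minimal sketch of the hand-off, with an illustrative module name:

    #include <memory>
    #include <utility>

    #include "llvm/ExecutionEngine/Orc/ThreadSafeModule.h"
    #include "llvm/IR/LLVMContext.h"
    #include "llvm/IR/Module.h"

    // Sketch: bundle a freshly built module with the context it was created
    // in. The JIT takes ownership of both, so the context is heap-allocated.
    llvm::orc::ThreadSafeModule MakeThreadSafeModule() {
      auto context = std::make_unique<llvm::LLVMContext>();
      auto module = std::make_unique<llvm::Module>("example_module", *context);
      return llvm::orc::ThreadSafeModule(std::move(module), std::move(context));
    }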

tensorflow/compiler/xla/service/cpu/cpu_executable.cc

@@ -63,14 +63,14 @@ CpuExecutable::CpuExecutable(
assignment_(std::move(assignment)) {
// Resolve symbols in the constructor rather than at execution time to avoid
// races because FindSymbol is not thread safe.
llvm::JITSymbol sym = jit_->FindCompiledSymbol(entry_function_name);
llvm::Expected<llvm::JITEvaluatedSymbol> sym =
jit_->FindCompiledSymbol(entry_function_name);
// We expect to find the symbol provided with entry_function_name; otherwise
// this is an internal error.
CHECK(sym) << "Symbol " << entry_function_name << " not found.";
// getAddress can do work under the hood in the jit, so it needs to be
// guarded by the mutex.
compute_function_ =
reinterpret_cast<ComputeFunctionType>(cantFail(sym.getAddress()));
compute_function_ = reinterpret_cast<ComputeFunctionType>(sym->getAddress());
VLOG(1) << "compute_function_ at address "
<< reinterpret_cast<void*>(compute_function_);
}
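
The new FindCompiledSymbol is backed by an OrcV2 lookup, which returns llvm::Expected<llvm::JITEvaluatedSymbol> with the address already resolved, so the separate, mutex-guarded getAddress() step from OrcV1 goes away. Roughly the pattern, under the same LLVM headers; LookupEntry and EntryFn are illustrative names:

    #include <cstdint>

    #include "llvm/ADT/StringRef.h"
    #include "llvm/ExecutionEngine/JITSymbol.h"
    #include "llvm/ExecutionEngine/Orc/Core.h"

    using EntryFn = void (*)();

    // Sketch: look a symbol up in a JITDylib and cast its address to a
    // callable entry point.
    llvm::Expected<EntryFn> LookupEntry(llvm::orc::ExecutionSession& session,
                                        llvm::orc::JITDylib& dylib,
                                        llvm::StringRef name) {
      llvm::Expected<llvm::JITEvaluatedSymbol> symbol =
          session.lookup({&dylib}, name);
      if (!symbol) {
        return symbol.takeError();
      }
      return reinterpret_cast<EntryFn>(
          static_cast<uintptr_t>(symbol->getAddress()));
    }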

tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc

@@ -85,6 +85,8 @@ SimpleOrcJIT::InferTargetMachineForJIT(
}
SimpleOrcJIT::SimpleOrcJIT(
std::unique_ptr<llvm::orc::TargetProcessControl> target_process_control,
std::unique_ptr<llvm::orc::ExecutionSession> execution_session,
const llvm::TargetOptions& target_options,
llvm::CodeGenOpt::Level opt_level, bool optimize_for_size,
bool disable_expensive_passes, llvm::FastMathFlags fast_math_flags,
@@ -93,48 +95,78 @@ SimpleOrcJIT::SimpleOrcJIT(
std::function<void(const llvm::object::ObjectFile&)> post_codegen_hook)
: target_machine_(InferTargetMachineForJIT(target_options, opt_level)),
data_layout_(target_machine_->createDataLayout()),
symbol_resolver_(llvm::orc::createLegacyLookupResolver(
execution_session_,
[this](llvm::StringRef name) -> llvm::JITSymbol {
return this->ResolveRuntimeSymbol(std::string(name));
},
[](llvm::Error Err) {
cantFail(std::move(Err), "lookupFlags failed");
})),
object_layer_(
execution_session_,
[this](llvm::orc::VModuleKey) {
llvm::orc::LegacyRTDyldObjectLinkingLayer::Resources result;
result.MemMgr = std::make_shared<llvm::SectionMemoryManager>(
orc_jit_memory_mapper::GetInstance());
result.Resolver = symbol_resolver_;
return result;
},
/*NotifyLoaded=*/
llvm::orc::LegacyRTDyldObjectLinkingLayer::NotifyLoadedFtor(),
/*NotifyFinalized=*/
[this](VModuleKeyT, const llvm::object::ObjectFile& object,
const llvm::RuntimeDyld::LoadedObjectInfo& object_info) {
this->NotifyObjectFinalized(object, object_info);
},
/*NotifyFreed=*/
[this](VModuleKeyT, const llvm::object::ObjectFile& object) {
this->NotifyObjectFreed(object);
}),
target_process_control_(std::move(target_process_control)),
execution_session_(std::move(execution_session)),
object_layer_(*execution_session_,
[]() {
return std::make_unique<llvm::SectionMemoryManager>(
orc_jit_memory_mapper::GetInstance());
}),
compile_layer_(
object_layer_,
CompilerFunctor(target_machine_.get(), opt_level, optimize_for_size,
disable_expensive_passes, fast_math_flags,
std::move(pre_optimization_hook),
std::move(post_optimization_hook),
std::move(post_codegen_hook))),
*execution_session_, object_layer_,
std::make_unique<CompilerFunctor>(
target_machine_.get(), opt_level, optimize_for_size,
disable_expensive_passes, fast_math_flags,
std::move(pre_optimization_hook),
std::move(post_optimization_hook), std::move(post_codegen_hook))),
main_jit_dylib_(&execution_session_->createBareJITDylib("<main>")),
gdb_jit_event_listener_(
llvm::JITEventListener::createGDBRegistrationListener()) {
VLOG(1) << "CPU target: " << target_machine_->getTargetCPU().str()
<< " features: " << target_machine_->getTargetFeatureString().str();
// Materialize unknown symbols from the runtime symbol table.
class RuntimeSymbolGenerator
: public llvm::orc::JITDylib::DefinitionGenerator {
SimpleOrcJIT& jit_;
public:
explicit RuntimeSymbolGenerator(SimpleOrcJIT& jit) : jit_(jit) {}
llvm::Error tryToGenerate(
llvm::orc::LookupKind, llvm::orc::JITDylib& jit_dylib,
llvm::orc::JITDylibLookupFlags,
const llvm::orc::SymbolLookupSet& names) override {
llvm::orc::SymbolMap new_defs;
for (const auto& kv : names) {
const auto& name = kv.first;
if (llvm::JITEvaluatedSymbol symbol =
jit_.ResolveRuntimeSymbol(*name)) {
new_defs[name] = symbol;
}
}
cantFail(jit_dylib.define(absoluteSymbols(std::move(new_defs))));
return llvm::Error::success();
}
};
main_jit_dylib_->addGenerator(
std::make_unique<RuntimeSymbolGenerator>(*this));
object_layer_.registerJITEventListener(*this);
}
llvm::JITSymbol SimpleOrcJIT::ResolveRuntimeSymbol(const std::string& name) {
llvm::Expected<std::unique_ptr<SimpleOrcJIT>> SimpleOrcJIT::Create(
const llvm::TargetOptions& target_options,
llvm::CodeGenOpt::Level opt_level, bool optimize_for_size,
bool disable_expensive_passes, llvm::FastMathFlags fast_math_flags,
LLVMCompiler::ModuleHook pre_optimization_hook,
LLVMCompiler::ModuleHook post_optimization_hook,
std::function<void(const llvm::object::ObjectFile&)> post_codegen_hook) {
auto target_process_control = llvm::orc::SelfTargetProcessControl::Create();
if (!target_process_control) {
return target_process_control.takeError();
}
auto execution_session = std::make_unique<llvm::orc::ExecutionSession>();
return std::make_unique<SimpleOrcJIT>(
std::move(*target_process_control), std::move(execution_session),
target_options, opt_level, optimize_for_size, disable_expensive_passes,
fast_math_flags, std::move(pre_optimization_hook),
std::move(post_optimization_hook), std::move(post_codegen_hook));
}
llvm::JITEvaluatedSymbol SimpleOrcJIT::ResolveRuntimeSymbol(
llvm::StringRef name) {
void* func_addr = nullptr;
if (name.size() > 1 && name.front() == data_layout_.getGlobalPrefix()) {
// On Mac OS X, 'name' may have a leading underscore prefix, even though the
@@ -143,12 +175,13 @@ llvm::JITSymbol SimpleOrcJIT::ResolveRuntimeSymbol(const std::string& name) {
func_addr =
xla::CustomCallTargetRegistry::Global()->Lookup(stripped_name, "Host");
} else {
func_addr = xla::CustomCallTargetRegistry::Global()->Lookup(name, "Host");
func_addr =
xla::CustomCallTargetRegistry::Global()->Lookup(name.str(), "Host");
}
if (func_addr == nullptr) {
LOG(ERROR)
<< "Unable to resolve runtime symbol: `" << name
<< "Unable to resolve runtime symbol: `" << name.str()
<< "'. Hint: if the symbol a custom call target, make sure you've "
"registered it with the JIT using "
"XLA_CPU_REGISTER_CUSTOM_CALL_TARGET.";
@@ -159,60 +192,25 @@ llvm::JITSymbol SimpleOrcJIT::ResolveRuntimeSymbol(const std::string& name) {
return symbol_info;
}
void SimpleOrcJIT::NotifyObjectFinalized(
void SimpleOrcJIT::notifyObjectLoaded(
llvm::JITEventListener::ObjectKey key,
const llvm::object::ObjectFile& object,
const llvm::RuntimeDyld::LoadedObjectInfo& object_info) {
uint64_t key = static_cast<uint64_t>(
reinterpret_cast<uintptr_t>(object.getData().data()));
gdb_jit_event_listener_->notifyObjectLoaded(key, object, object_info);
size_of_generated_code_in_bytes_ += object.getData().size();
}
void SimpleOrcJIT::NotifyObjectFreed(const llvm::object::ObjectFile& object) {
uint64_t key = static_cast<uint64_t>(
reinterpret_cast<uintptr_t>(object.getData().data()));
void SimpleOrcJIT::notifyFreeingObject(llvm::JITEventListener::ObjectKey key) {
gdb_jit_event_listener_->notifyFreeingObject(key);
}
SimpleOrcJIT::VModuleKeyT SimpleOrcJIT::AddModule(
std::unique_ptr<llvm::Module> module) {
auto key = execution_session_.allocateVModule();
cantFail(compile_layer_.addModule(key, std::move(module)));
module_keys_.push_back(key);
return key;
llvm::Error SimpleOrcJIT::AddModule(llvm::orc::ThreadSafeModule module) {
return compile_layer_.add(*main_jit_dylib_, std::move(module));
}
void SimpleOrcJIT::RemoveModule(SimpleOrcJIT::VModuleKeyT key) {
module_keys_.erase(std::remove(module_keys_.begin(), module_keys_.end(), key),
module_keys_.end());
cantFail(compile_layer_.removeModule(key));
}
llvm::JITSymbol SimpleOrcJIT::FindCompiledSymbol(const std::string& name) {
#ifdef _WIN32
// The symbol lookup of ObjectLinkingLayer uses the SymbolRef::SF_Exported
// flag to decide whether a symbol will be visible or not, when we call
// IRCompileLayer::findSymbolIn with ExportedSymbolsOnly set to true.
//
// But for Windows COFF objects, this flag is currently never set.
// For a potential solution see: https://reviews.llvm.org/rL258665
// For now, we allow non-exported symbols on Windows as a workaround.
const bool exported_symbols_only = false;
#else
const bool exported_symbols_only = true;
#endif
// Resolve symbol from last module to first, allowing later redefinitions of
// symbols shadow earlier ones.
for (auto& key :
llvm::make_range(module_keys_.rbegin(), module_keys_.rend())) {
if (auto symbol =
compile_layer_.findSymbolIn(key, name, exported_symbols_only)) {
return symbol;
}
}
return nullptr;
llvm::Expected<llvm::JITEvaluatedSymbol> SimpleOrcJIT::FindCompiledSymbol(
const std::string& name) {
return execution_session_->lookup({main_jit_dylib_}, name);
}
#if defined(PLATFORM_WINDOWS)
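
Taken together, the constructor above assembles the standard OrcV2 stack: an ExecutionSession owning a JITDylib, an RTDyldObjectLinkingLayer doing the linking, and an IRCompileLayer driving the IRCompiler (CompilerFunctor here). The sketch below shows the same layering in isolation, assuming the LLVM headers of this era; TinyJit is an illustrative name and omits the XLA memory mapper, the RuntimeSymbolGenerator, and the GDB JITEventListener wiring seen in the diff.

    #include <memory>
    #include <utility>

    #include "llvm/ADT/StringRef.h"
    #include "llvm/ExecutionEngine/JITSymbol.h"
    #include "llvm/ExecutionEngine/Orc/Core.h"
    #include "llvm/ExecutionEngine/Orc/IRCompileLayer.h"
    #include "llvm/ExecutionEngine/Orc/RTDyldObjectLinkingLayer.h"
    #include "llvm/ExecutionEngine/Orc/ThreadSafeModule.h"
    #include "llvm/ExecutionEngine/SectionMemoryManager.h"
    #include "llvm/Support/Error.h"

    // Sketch of the OrcV2 layering: the session owns JITDylibs, the object
    // layer links the objects emitted by the compile layer, and the compile
    // layer drives an IRCompiler.
    struct TinyJit {
      llvm::orc::ExecutionSession session;
      llvm::orc::RTDyldObjectLinkingLayer object_layer{session, [] {
        return std::make_unique<llvm::SectionMemoryManager>();
      }};
      llvm::orc::IRCompileLayer compile_layer;
      llvm::orc::JITDylib& main_dylib = session.createBareJITDylib("<main>");

      explicit TinyJit(
          std::unique_ptr<llvm::orc::IRCompileLayer::IRCompiler> compiler)
          : compile_layer(session, object_layer, std::move(compiler)) {}

      llvm::Error AddModule(llvm::orc::ThreadSafeModule module) {
        return compile_layer.add(main_dylib, std::move(module));
      }

      // Compilation happens lazily: the first lookup materializes the symbol.
      llvm::Expected<llvm::JITEvaluatedSymbol> Lookup(llvm::StringRef name) {
        return session.lookup({&main_dylib}, name);
      }
    };

Usage mirrors the new RunBackend flow: construct with a CompilerFunctor, add a ThreadSafeModule, then look up the entry symbol, which is what actually triggers compilation.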

tensorflow/compiler/xla/service/cpu/simple_orc_jit.h

@@ -26,6 +26,7 @@ limitations under the License.
#include "llvm/ExecutionEngine/Orc/IRCompileLayer.h"
#include "llvm/ExecutionEngine/Orc/RTDyldObjectLinkingLayer.h"
#include "llvm/ExecutionEngine/Orc/SymbolStringPool.h"
#include "llvm/ExecutionEngine/Orc/TargetProcessControl.h"
#include "llvm/IR/Module.h"
#include "llvm/Target/TargetMachine.h"
#include "tensorflow/compiler/xla/service/cpu/compiler_functor.h"
@@ -42,13 +43,10 @@ namespace cpu {
// Supports JIT-ing multiple modules but without cross-module linking.
// Implements eager compilation - the module is lowered to binary as soon as
// it's added to the JIT.
class SimpleOrcJIT {
class SimpleOrcJIT : public llvm::JITEventListener {
public:
using ObjLayerT = llvm::orc::LegacyRTDyldObjectLinkingLayer;
using CompileFtor =
std::function<llvm::Expected<ObjLayerT::ObjectPtr>(llvm::Module&)>;
using CompileLayerT = llvm::orc::LegacyIRCompileLayer<ObjLayerT, CompileFtor>;
using VModuleKeyT = llvm::orc::VModuleKey;
using ObjLayerT = llvm::orc::RTDyldObjectLinkingLayer;
using CompileLayerT = llvm::orc::IRCompileLayer;
// Create a new JIT, targeting the host architecture.
//
@@ -56,6 +54,16 @@ class SimpleOrcJIT {
// LLVM IR-level optimizations. post_codegen_hook is invoked after
// compiling to machine code.
SimpleOrcJIT(
std::unique_ptr<llvm::orc::TargetProcessControl> target_process_control,
std::unique_ptr<llvm::orc::ExecutionSession> execution_session,
const llvm::TargetOptions& target_options,
llvm::CodeGenOpt::Level opt_level, bool optimize_for_size,
bool disable_expensive_passes, llvm::FastMathFlags fast_math_flags,
LLVMCompiler::ModuleHook pre_optimization_hook,
LLVMCompiler::ModuleHook post_optimization_hook,
std::function<void(const llvm::object::ObjectFile&)> post_codegen_hook);
static llvm::Expected<std::unique_ptr<SimpleOrcJIT>> Create(
const llvm::TargetOptions& target_options,
llvm::CodeGenOpt::Level opt_level, bool optimize_for_size,
bool disable_expensive_passes, llvm::FastMathFlags fast_math_flags,
@@ -69,16 +77,12 @@ class SimpleOrcJIT {
return target_machine_->getTargetTriple();
}
// Add a module to the JIT. Returns an opaque key that can be used to later
// remove this module.
VModuleKeyT AddModule(std::unique_ptr<llvm::Module> module);
// Remove a module from the JIT and free the memory associated with it.
void RemoveModule(VModuleKeyT key);
llvm::Error AddModule(llvm::orc::ThreadSafeModule module);
// Get the runtime address of the compiled symbol whose name is given. Returns
// an error if the symbol cannot be found.
llvm::JITSymbol FindCompiledSymbol(const std::string& name);
llvm::Expected<llvm::JITEvaluatedSymbol> FindCompiledSymbol(
const std::string& name);
llvm::TargetMachine* target_machine() const { return target_machine_.get(); }
@@ -93,20 +97,21 @@ class SimpleOrcJIT {
}
private:
llvm::JITSymbol ResolveRuntimeSymbol(const std::string& name);
llvm::JITEvaluatedSymbol ResolveRuntimeSymbol(llvm::StringRef name);
void NotifyObjectFinalized(
void notifyObjectLoaded(
llvm::JITEventListener::ObjectKey key,
const llvm::object::ObjectFile& object,
const llvm::RuntimeDyld::LoadedObjectInfo& object_info);
void NotifyObjectFreed(const llvm::object::ObjectFile& object);
const llvm::RuntimeDyld::LoadedObjectInfo& object_info) override;
void notifyFreeingObject(llvm::JITEventListener::ObjectKey key) override;
std::vector<VModuleKeyT> module_keys_;
std::unique_ptr<llvm::TargetMachine> target_machine_;
const llvm::DataLayout data_layout_;
llvm::orc::ExecutionSession execution_session_;
std::shared_ptr<llvm::orc::SymbolResolver> symbol_resolver_;
std::unique_ptr<llvm::orc::TargetProcessControl> target_process_control_;
std::unique_ptr<llvm::orc::ExecutionSession> execution_session_;
ObjLayerT object_layer_;
CompileLayerT compile_layer_;
llvm::orc::JITDylib* main_jit_dylib_;
int64 size_of_generated_code_in_bytes_ = 0;
// Non owning pointer to a JIT event listener that registers the JIT events