diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD
index 0cc27e32749..a6d365d94a0 100644
--- a/tensorflow/compiler/xla/service/cpu/BUILD
+++ b/tensorflow/compiler/xla/service/cpu/BUILD
@@ -552,6 +552,7 @@ cc_library(
         "@llvm-project//llvm:IPO",
         "@llvm-project//llvm:MC",
         "@llvm-project//llvm:Object",
+        "@llvm-project//llvm:OrcJIT",
         "@llvm-project//llvm:Support",
         "@llvm-project//llvm:Target",
     ],
diff --git a/tensorflow/compiler/xla/service/cpu/compiler_functor.cc b/tensorflow/compiler/xla/service/cpu/compiler_functor.cc
index a21ace0d8b2..643de6c4e58 100644
--- a/tensorflow/compiler/xla/service/cpu/compiler_functor.cc
+++ b/tensorflow/compiler/xla/service/cpu/compiler_functor.cc
@@ -80,8 +80,8 @@ class FilteredPassManager : public llvm::legacy::PassManager {
 };
 }  // anonymous namespace
 
-std::unique_ptr<llvm::MemoryBuffer> CompilerFunctor::operator()(
-    llvm::Module& module) const {
+llvm::Expected<std::unique_ptr<llvm::MemoryBuffer>> CompilerFunctor::operator()(
+    llvm::Module& module) {
   FilteredPassManager module_passes(disable_expensive_passes_);
   llvm::legacy::FunctionPassManager function_passes(&module);
 
@@ -155,7 +155,7 @@ std::unique_ptr<llvm::MemoryBuffer> CompilerFunctor::operator()(
     }
   }
 
-  return memory_buffer;
+  return std::move(memory_buffer);
 }
 
 static std::vector<llvm::VecDesc> VectorFunctionsForTargetLibraryInfoImpl() {
diff --git a/tensorflow/compiler/xla/service/cpu/compiler_functor.h b/tensorflow/compiler/xla/service/cpu/compiler_functor.h
index 647f0d18ef5..6211588861b 100644
--- a/tensorflow/compiler/xla/service/cpu/compiler_functor.h
+++ b/tensorflow/compiler/xla/service/cpu/compiler_functor.h
@@ -16,6 +16,7 @@ limitations under the License.
 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_COMPILER_FUNCTOR_H_
 #define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_COMPILER_FUNCTOR_H_
 
+#include "llvm/ExecutionEngine/Orc/IRCompileLayer.h"
 #include "llvm/IR/LegacyPassManager.h"
 #include "llvm/IR/Module.h"
 #include "llvm/IR/Operator.h"
@@ -29,7 +30,7 @@ namespace cpu {
 
 // Functor class for compiling an LLVM module down to an object file. For use by
 // Orc JIT compile layer.
-class CompilerFunctor {
+class CompilerFunctor : public llvm::orc::IRCompileLayer::IRCompiler {
  public:
   explicit CompilerFunctor(
       llvm::TargetMachine* target_machine, int opt_level,
@@ -39,7 +40,8 @@ class CompilerFunctor {
       LLVMCompiler::ModuleHook post_optimization_hook = nullptr,
       std::function<void(const llvm::object::ObjectFile&)> post_codegen_hook =
          nullptr)
-      : target_machine_(target_machine),
+      : IRCompiler(llvm::orc::IRSymbolMapper::ManglingOptions()),
+        target_machine_(target_machine),
         opt_level_(opt_level),
         optimize_for_size_(optimize_for_size),
         disable_expensive_passes_(disable_expensive_passes),
@@ -49,8 +51,8 @@ class CompilerFunctor {
         post_codegen_hook_(std::move(post_codegen_hook)) {}
 
   // Compile a Module to an ObjectFile.
-  std::unique_ptr<llvm::MemoryBuffer> operator()(
-      llvm::Module& module) const;  // NOLINT
+  llvm::Expected<std::unique_ptr<llvm::MemoryBuffer>> operator()(
+      llvm::Module& module) override;
 
  private:
   // Populates the given pass manager with TargetLibraryInfo and
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
index 1ffafd37a27..e92f890ba67 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
@@ -19,6 +19,7 @@ limitations under the License.
 #include
 #include
+#include
 #include
 #include
 #include
@@ -38,6 +39,7 @@ limitations under the License.
#include "llvm/IR/Verifier.h" #include "llvm/Object/ObjectFile.h" #include "llvm/Support/CommandLine.h" +#include "llvm/Support/Error.h" #include "llvm/Support/TargetRegistry.h" #include "llvm/Support/TargetSelect.h" #include "llvm/Target/TargetMachine.h" @@ -645,11 +647,11 @@ StatusOr> CpuCompiler::RunBackend( // Compile must be thread-safe so create a new LLVM context for the module. mlir::MLIRContext mlir_context; LoadMLIRDialects(mlir_context); - llvm::LLVMContext llvm_context; + auto llvm_context = std::make_unique(); auto llvm_module = - absl::make_unique("__compute_module", llvm_context); + absl::make_unique("__compute_module", *llvm_context); - auto jit = absl::make_unique( + auto jit = SimpleOrcJIT::Create( CompilerTargetOptions(module->config()), CodeGenOptLevel(module->config()), options::OptimizeForSizeRequested(module->config()), @@ -657,8 +659,12 @@ StatusOr> CpuCompiler::RunBackend( llvm_ir::GetCpuFastMathFlags(module->config()), pre_optimization_ir_hook, post_optimization_ir_hook, OrcJITPostCompilationHook::Create(module.get())); - llvm_module->setDataLayout(jit->data_layout()); - llvm_module->setTargetTriple(jit->target_triple().getTriple()); + if (!jit) { + return InternalError("Creating JIT failed: %s", + llvm::toString(jit.takeError())); + } + llvm_module->setDataLayout((*jit)->data_layout()); + llvm_module->setTargetTriple((*jit)->target_triple().getTriple()); HloComputation* entry_computation = module->entry_computation(); std::unordered_map instruction_to_profile_idx; @@ -700,7 +706,7 @@ StatusOr> CpuCompiler::RunBackend( // GetEmbeddedComputations guarantees that a called computation occurs // before a caller computation. - LLVMTargetMachineFeatures target_machine_features(jit->target_machine()); + LLVMTargetMachineFeatures target_machine_features((*jit)->target_machine()); IrEmitter ir_emitter(&mlir_context, *module, *assignment, llvm_module.get(), std::move(instruction_to_profile_idx), std::move(computation_to_profile_idx), @@ -739,7 +745,7 @@ StatusOr> CpuCompiler::RunBackend( string function_name = [&]() { llvm::SmallVector function_name_vector; llvm::Mangler::getNameWithPrefix( - function_name_vector, entry_function->getName(), jit->data_layout()); + function_name_vector, entry_function->getName(), (*jit)->data_layout()); return string(function_name_vector.begin(), function_name_vector.end()); }(); @@ -751,9 +757,11 @@ StatusOr> CpuCompiler::RunBackend( TF_RETURN_IF_ERROR(VerifyLlvmModule(*llvm_module)); // JIT compile the LLVM IR module to in-memory machine code. 
-  jit->AddModule(std::move(llvm_module));
+  llvm::orc::ThreadSafeModule thread_safe_module(std::move(llvm_module),
+                                                 std::move(llvm_context));
+  cantFail((*jit)->AddModule(std::move(thread_safe_module)));
   cpu_executable.reset(new CpuExecutable(
-      std::move(jit), std::move(assignment), std::move(module), function_name,
+      std::move(*jit), std::move(assignment), std::move(module), function_name,
       std::move(hlo_profile_printer_data), std::move(hlo_profile_index_map)));
 
   if (embed_ir_in_executable) {
@@ -971,7 +979,7 @@ CpuCompiler::CompileAheadOfTime(std::unique_ptr<HloModuleGroup> module_group,
         llvm_ir::GetCpuFastMathFlags(module->config()), pre_optimization_ir_hook,
         post_optimization_ir_hook, post_codegen_hook);
     std::unique_ptr<llvm::MemoryBuffer> object_file =
-        compiler_functor(llvm_module);
+        cantFail(compiler_functor(llvm_module));
     ObjectFileData object_file_data(object_file->getBufferStart(),
                                     object_file->getBufferEnd());
 
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc
index 02bc445ce9a..5bbf905ce0b 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc
@@ -63,14 +63,14 @@ CpuExecutable::CpuExecutable(
       assignment_(std::move(assignment)) {
   // Resolve symbols in the constructor rather than at execution time to avoid
   // races because FindSymbol is not thread safe.
-  llvm::JITSymbol sym = jit_->FindCompiledSymbol(entry_function_name);
+  llvm::Expected<llvm::JITEvaluatedSymbol> sym =
+      jit_->FindCompiledSymbol(entry_function_name);
   // We expect to find the symbol provided with entry_function_name; otherwise
   // this is an internal error.
   CHECK(sym) << "Symbol " << entry_function_name << " not found.";
   // getAddress can do work under the hood in the jit, so it needs to be
   // guarded by the mutex.
-  compute_function_ =
-      reinterpret_cast<ComputeFunctionType>(cantFail(sym.getAddress()));
+  compute_function_ = reinterpret_cast<ComputeFunctionType>(sym->getAddress());
   VLOG(1) << "compute_function_ at address "
           << reinterpret_cast<void*>(compute_function_);
 }
diff --git a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc
index 28508bde4cd..5556ee7c467 100644
--- a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc
+++ b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc
@@ -85,6 +85,8 @@ SimpleOrcJIT::InferTargetMachineForJIT(
 }
 
 SimpleOrcJIT::SimpleOrcJIT(
+    std::unique_ptr<llvm::orc::TargetProcessControl> target_process_control,
+    std::unique_ptr<llvm::orc::ExecutionSession> execution_session,
     const llvm::TargetOptions& target_options,
    llvm::CodeGenOpt::Level opt_level, bool optimize_for_size,
     bool disable_expensive_passes, llvm::FastMathFlags fast_math_flags,
@@ -93,48 +95,78 @@ SimpleOrcJIT::SimpleOrcJIT(
     std::function<void(const llvm::object::ObjectFile&)> post_codegen_hook)
     : target_machine_(InferTargetMachineForJIT(target_options, opt_level)),
       data_layout_(target_machine_->createDataLayout()),
-      symbol_resolver_(llvm::orc::createLegacyLookupResolver(
-          execution_session_,
-          [this](llvm::StringRef name) -> llvm::JITSymbol {
-            return this->ResolveRuntimeSymbol(std::string(name));
-          },
-          [](llvm::Error Err) {
-            cantFail(std::move(Err), "lookupFlags failed");
-          })),
-      object_layer_(
-          execution_session_,
-          [this](llvm::orc::VModuleKey) {
-            llvm::orc::LegacyRTDyldObjectLinkingLayer::Resources result;
-            result.MemMgr = std::make_shared<llvm::SectionMemoryManager>(
-                orc_jit_memory_mapper::GetInstance());
-            result.Resolver = symbol_resolver_;
-            return result;
-          },
-          /*NotifyLoaded=*/
-          llvm::orc::LegacyRTDyldObjectLinkingLayer::NotifyLoadedFtor(),
-          /*NotifyFinalized=*/
-          [this](VModuleKeyT, const llvm::object::ObjectFile& object,
-                 const llvm::RuntimeDyld::LoadedObjectInfo& object_info) {
-            this->NotifyObjectFinalized(object, object_info);
-          },
-          /*NotifyFreed=*/
-          [this](VModuleKeyT, const llvm::object::ObjectFile& object) {
-            this->NotifyObjectFreed(object);
-          }),
+      target_process_control_(std::move(target_process_control)),
+      execution_session_(std::move(execution_session)),
+      object_layer_(*execution_session_,
+                    []() {
+                      return std::make_unique<llvm::SectionMemoryManager>(
+                          orc_jit_memory_mapper::GetInstance());
+                    }),
       compile_layer_(
-          object_layer_,
-          CompilerFunctor(target_machine_.get(), opt_level, optimize_for_size,
-                          disable_expensive_passes, fast_math_flags,
-                          std::move(pre_optimization_hook),
-                          std::move(post_optimization_hook),
-                          std::move(post_codegen_hook))),
+          *execution_session_, object_layer_,
+          std::make_unique<CompilerFunctor>(
+              target_machine_.get(), opt_level, optimize_for_size,
+              disable_expensive_passes, fast_math_flags,
+              std::move(pre_optimization_hook),
+              std::move(post_optimization_hook), std::move(post_codegen_hook))),
+      main_jit_dylib_(&execution_session_->createBareJITDylib("<main>")),
       gdb_jit_event_listener_(
           llvm::JITEventListener::createGDBRegistrationListener()) {
   VLOG(1) << "CPU target: " << target_machine_->getTargetCPU().str()
           << " features: " << target_machine_->getTargetFeatureString().str();
+
+  // Materialize unknown symbols from the runtime symbol table.
+  class RuntimeSymbolGenerator
+      : public llvm::orc::JITDylib::DefinitionGenerator {
+    SimpleOrcJIT& jit_;
+
+   public:
+    explicit RuntimeSymbolGenerator(SimpleOrcJIT& jit) : jit_(jit) {}
+    llvm::Error tryToGenerate(
+        llvm::orc::LookupKind, llvm::orc::JITDylib& jit_dylib,
+        llvm::orc::JITDylibLookupFlags,
+        const llvm::orc::SymbolLookupSet& names) override {
+      llvm::orc::SymbolMap new_defs;
+
+      for (const auto& kv : names) {
+        const auto& name = kv.first;
+        if (llvm::JITEvaluatedSymbol symbol =
+                jit_.ResolveRuntimeSymbol(*name)) {
+          new_defs[name] = symbol;
+        }
+      }
+
+      cantFail(jit_dylib.define(absoluteSymbols(std::move(new_defs))));
+      return llvm::Error::success();
+    }
+  };
+  main_jit_dylib_->addGenerator(
+      std::make_unique<RuntimeSymbolGenerator>(*this));
+  object_layer_.registerJITEventListener(*this);
 }
 
-llvm::JITSymbol SimpleOrcJIT::ResolveRuntimeSymbol(const std::string& name) {
+llvm::Expected<std::unique_ptr<SimpleOrcJIT>> SimpleOrcJIT::Create(
+    const llvm::TargetOptions& target_options,
+    llvm::CodeGenOpt::Level opt_level, bool optimize_for_size,
+    bool disable_expensive_passes, llvm::FastMathFlags fast_math_flags,
+    LLVMCompiler::ModuleHook pre_optimization_hook,
+    LLVMCompiler::ModuleHook post_optimization_hook,
+    std::function<void(const llvm::object::ObjectFile&)> post_codegen_hook) {
+  auto target_process_control = llvm::orc::SelfTargetProcessControl::Create();
+  if (!target_process_control) {
+    return target_process_control.takeError();
+  }
+
+  auto execution_session = std::make_unique<llvm::orc::ExecutionSession>();
+  return std::make_unique<SimpleOrcJIT>(
+      std::move(*target_process_control), std::move(execution_session),
+      target_options, opt_level, optimize_for_size, disable_expensive_passes,
+      fast_math_flags, std::move(pre_optimization_hook),
+      std::move(post_optimization_hook), std::move(post_codegen_hook));
+}
+
+llvm::JITEvaluatedSymbol SimpleOrcJIT::ResolveRuntimeSymbol(
+    llvm::StringRef name) {
   void* func_addr = nullptr;
   if (name.size() > 1 && name.front() == data_layout_.getGlobalPrefix()) {
     // On Mac OS X, 'name' may have a leading underscore prefix, even though the
@@ -143,12 +175,13 @@ llvm::JITSymbol SimpleOrcJIT::ResolveRuntimeSymbol(const std::string& name) {
     func_addr =
         xla::CustomCallTargetRegistry::Global()->Lookup(stripped_name, "Host");
   } else {
-    func_addr = xla::CustomCallTargetRegistry::Global()->Lookup(name, "Host");
+    func_addr =
+        xla::CustomCallTargetRegistry::Global()->Lookup(name.str(), "Host");
   }
 
   if (func_addr == nullptr) {
     LOG(ERROR)
-        << "Unable to resolve runtime symbol: `" << name
+        << "Unable to resolve runtime symbol: `" << name.str()
        << "'. Hint: if the symbol a custom call target, make sure you've "
           "registered it with the JIT using "
           "XLA_CPU_REGISTER_CUSTOM_CALL_TARGET.";
@@ -159,60 +192,25 @@ llvm::JITSymbol SimpleOrcJIT::ResolveRuntimeSymbol(const std::string& name) {
   return symbol_info;
 }
 
-void SimpleOrcJIT::NotifyObjectFinalized(
+void SimpleOrcJIT::notifyObjectLoaded(
+    llvm::JITEventListener::ObjectKey key,
     const llvm::object::ObjectFile& object,
     const llvm::RuntimeDyld::LoadedObjectInfo& object_info) {
-  uint64_t key = static_cast<uint64_t>(
-      reinterpret_cast<uintptr_t>(object.getData().data()));
   gdb_jit_event_listener_->notifyObjectLoaded(key, object, object_info);
   size_of_generated_code_in_bytes_ += object.getData().size();
 }
 
-void SimpleOrcJIT::NotifyObjectFreed(const llvm::object::ObjectFile& object) {
-  uint64_t key = static_cast<uint64_t>(
-      reinterpret_cast<uintptr_t>(object.getData().data()));
+void SimpleOrcJIT::notifyFreeingObject(llvm::JITEventListener::ObjectKey key) {
   gdb_jit_event_listener_->notifyFreeingObject(key);
 }
 
-SimpleOrcJIT::VModuleKeyT SimpleOrcJIT::AddModule(
-    std::unique_ptr<llvm::Module> module) {
-  auto key = execution_session_.allocateVModule();
-  cantFail(compile_layer_.addModule(key, std::move(module)));
-  module_keys_.push_back(key);
-  return key;
+llvm::Error SimpleOrcJIT::AddModule(llvm::orc::ThreadSafeModule module) {
+  return compile_layer_.add(*main_jit_dylib_, std::move(module));
 }
 
-void SimpleOrcJIT::RemoveModule(SimpleOrcJIT::VModuleKeyT key) {
-  module_keys_.erase(std::remove(module_keys_.begin(), module_keys_.end(), key),
-                     module_keys_.end());
-  cantFail(compile_layer_.removeModule(key));
-}
-
-llvm::JITSymbol SimpleOrcJIT::FindCompiledSymbol(const std::string& name) {
-#ifdef _WIN32
-  // The symbol lookup of ObjectLinkingLayer uses the SymbolRef::SF_Exported
-  // flag to decide whether a symbol will be visible or not, when we call
-  // IRCompileLayer::findSymbolIn with ExportedSymbolsOnly set to true.
-  //
-  // But for Windows COFF objects, this flag is currently never set.
-  // For a potential solution see: https://reviews.llvm.org/rL258665
-  // For now, we allow non-exported symbols on Windows as a workaround.
-  const bool exported_symbols_only = false;
-#else
-  const bool exported_symbols_only = true;
-#endif
-
-  // Resolve symbol from last module to first, allowing later redefinitions of
-  // symbols shadow earlier ones.
-  for (auto& key :
-       llvm::make_range(module_keys_.rbegin(), module_keys_.rend())) {
-    if (auto symbol =
-            compile_layer_.findSymbolIn(key, name, exported_symbols_only)) {
-      return symbol;
-    }
-  }
-
-  return nullptr;
+llvm::Expected<llvm::JITEvaluatedSymbol> SimpleOrcJIT::FindCompiledSymbol(
+    const std::string& name) {
+  return execution_session_->lookup({main_jit_dylib_}, name);
 }
 
 #if defined(PLATFORM_WINDOWS)
diff --git a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.h b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.h
index 9c470edbac2..714df6b0f87 100644
--- a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.h
+++ b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.h
@@ -26,6 +26,7 @@ limitations under the License.
 #include "llvm/ExecutionEngine/Orc/IRCompileLayer.h"
 #include "llvm/ExecutionEngine/Orc/RTDyldObjectLinkingLayer.h"
 #include "llvm/ExecutionEngine/Orc/SymbolStringPool.h"
+#include "llvm/ExecutionEngine/Orc/TargetProcessControl.h"
 #include "llvm/IR/Module.h"
 #include "llvm/Target/TargetMachine.h"
 
@@ -42,13 +43,10 @@ namespace cpu {
 // Supports JIT-ing multiple modules but without cross-module linking.
 // Implements eager compilation - the module is lowered to binary as soon as
 // it's added to the JIT.
-class SimpleOrcJIT {
+class SimpleOrcJIT : public llvm::JITEventListener {
  public:
-  using ObjLayerT = llvm::orc::LegacyRTDyldObjectLinkingLayer;
-  using CompileFtor =
-      std::function<std::unique_ptr<llvm::MemoryBuffer>(llvm::Module&)>;
-  using CompileLayerT = llvm::orc::LegacyIRCompileLayer<ObjLayerT, CompileFtor>;
-  using VModuleKeyT = llvm::orc::VModuleKey;
+  using ObjLayerT = llvm::orc::RTDyldObjectLinkingLayer;
+  using CompileLayerT = llvm::orc::IRCompileLayer;
 
   // Create a new JIT, targeting the host architecture.
   //
@@ -56,6 +54,16 @@ class SimpleOrcJIT {
   // LLVM IR-level optimizations. post_codegen_hook is invoked after
   // compiling to machine code.
   SimpleOrcJIT(
+      std::unique_ptr<llvm::orc::TargetProcessControl> target_process_control,
+      std::unique_ptr<llvm::orc::ExecutionSession> execution_session,
+      const llvm::TargetOptions& target_options,
+      llvm::CodeGenOpt::Level opt_level, bool optimize_for_size,
+      bool disable_expensive_passes, llvm::FastMathFlags fast_math_flags,
+      LLVMCompiler::ModuleHook pre_optimization_hook,
+      LLVMCompiler::ModuleHook post_optimization_hook,
+      std::function<void(const llvm::object::ObjectFile&)> post_codegen_hook);
+
+  static llvm::Expected<std::unique_ptr<SimpleOrcJIT>> Create(
       const llvm::TargetOptions& target_options,
       llvm::CodeGenOpt::Level opt_level, bool optimize_for_size,
       bool disable_expensive_passes, llvm::FastMathFlags fast_math_flags,
@@ -69,16 +77,12 @@ class SimpleOrcJIT {
     return target_machine_->getTargetTriple();
   }
 
-  // Add a module to the JIT. Returns an opaque key that can be used to later
-  // remove this module.
-  VModuleKeyT AddModule(std::unique_ptr<llvm::Module> module);
-
-  // Remove a module from the JIT and free the memory associated with it.
-  void RemoveModule(VModuleKeyT key);
+  llvm::Error AddModule(llvm::orc::ThreadSafeModule module);
 
   // Get the runtime address of the compiled symbol whose name is given. Returns
   // nullptr if the symbol cannot be found.
-  llvm::JITSymbol FindCompiledSymbol(const std::string& name);
+  llvm::Expected<llvm::JITEvaluatedSymbol> FindCompiledSymbol(
+      const std::string& name);
 
   llvm::TargetMachine* target_machine() const { return target_machine_.get(); }
 
@@ -93,20 +97,21 @@ class SimpleOrcJIT {
   }
 
  private:
-  llvm::JITSymbol ResolveRuntimeSymbol(const std::string& name);
+  llvm::JITEvaluatedSymbol ResolveRuntimeSymbol(llvm::StringRef name);
 
-  void NotifyObjectFinalized(
+  void notifyObjectLoaded(
+      llvm::JITEventListener::ObjectKey key,
       const llvm::object::ObjectFile& object,
-      const llvm::RuntimeDyld::LoadedObjectInfo& object_info);
-  void NotifyObjectFreed(const llvm::object::ObjectFile& object);
+      const llvm::RuntimeDyld::LoadedObjectInfo& object_info) override;
+  void notifyFreeingObject(llvm::JITEventListener::ObjectKey key) override;
 
-  std::vector<VModuleKeyT> module_keys_;
   std::unique_ptr<llvm::TargetMachine> target_machine_;
   const llvm::DataLayout data_layout_;
-  llvm::orc::ExecutionSession execution_session_;
-  std::shared_ptr<llvm::orc::SymbolResolver> symbol_resolver_;
+  std::unique_ptr<llvm::orc::TargetProcessControl> target_process_control_;
+  std::unique_ptr<llvm::orc::ExecutionSession> execution_session_;
   ObjLayerT object_layer_;
   CompileLayerT compile_layer_;
+  llvm::orc::JITDylib* main_jit_dylib_;
   int64 size_of_generated_code_in_bytes_ = 0;
 
   // Non owning pointer to a JIT event listener that registers the JIT events
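
For reference, the ORCv2-style call sequence that this patch introduces in cpu_compiler.cc and cpu_executable.cc composes roughly as the C++ sketch below. It is illustrative only and not part of the patch: target_options, opt_level, fast_math_flags, llvm_module, llvm_context, and entry_function_name stand in for values that CpuCompiler::RunBackend plumbs through from the HLO module config.

    // Sketch only: the JIT is now built through a factory that returns
    // llvm::Expected instead of a bare object, so failures surface as Errors.
    llvm::Expected<std::unique_ptr<xla::cpu::SimpleOrcJIT>> jit =
        xla::cpu::SimpleOrcJIT::Create(
            target_options, opt_level, /*optimize_for_size=*/false,
            /*disable_expensive_passes=*/false, fast_math_flags,
            /*pre_optimization_hook=*/nullptr,
            /*post_optimization_hook=*/nullptr,
            /*post_codegen_hook=*/nullptr);
    if (!jit) {
      return InternalError("Creating JIT failed: %s",
                           llvm::toString(jit.takeError()));
    }

    // Modules are handed over as a ThreadSafeModule that owns its LLVMContext,
    // and symbol lookup goes through the ExecutionSession, so it returns
    // llvm::Expected<llvm::JITEvaluatedSymbol> rather than llvm::JITSymbol.
    llvm::orc::ThreadSafeModule thread_safe_module(std::move(llvm_module),
                                                   std::move(llvm_context));
    cantFail((*jit)->AddModule(std::move(thread_safe_module)));
    llvm::Expected<llvm::JITEvaluatedSymbol> sym =
        (*jit)->FindCompiledSymbol(entry_function_name);
    CHECK(sym) << "Symbol " << entry_function_name << " not found.";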