From d18653005c8ff077492b541538538de0b2ec55ac Mon Sep 17 00:00:00 2001 From: Diego Caballero Date: Mon, 5 Aug 2019 22:13:56 -0700 Subject: [PATCH] Add TTI pass initialization to pass managers. Many LLVM transformations benefit from knowing the targets. This enables optimizations, especially in a JIT context when the target is (generally) well-known. Closes #49 COPYBARA_INTEGRATE_REVIEW=https://github.com/tensorflow/mlir/pull/49 from dcaballe:dcaballe/tti ab02f72eb326f660945696e5dadeeb983cf263b3 PiperOrigin-RevId: 261840617 --- .../include/mlir/ExecutionEngine/OptUtils.h | 13 +++++-- .../mlir/lib/ExecutionEngine/OptUtils.cpp | 34 ++++++++++++++----- third_party/mlir/lib/Support/JitRunner.cpp | 4 +-- 3 files changed, 37 insertions(+), 14 deletions(-) diff --git a/third_party/mlir/include/mlir/ExecutionEngine/OptUtils.h b/third_party/mlir/include/mlir/ExecutionEngine/OptUtils.h index 86ca212e9a2..8c0249d5c09 100644 --- a/third_party/mlir/include/mlir/ExecutionEngine/OptUtils.h +++ b/third_party/mlir/include/mlir/ExecutionEngine/OptUtils.h @@ -31,6 +31,7 @@ namespace llvm { class Module; class Error; +class TargetMachine; } // namespace llvm namespace mlir { @@ -41,17 +42,23 @@ void initializeLLVMPasses(); /// Create a module transformer function for MLIR ExecutionEngine that runs /// LLVM IR passes corresponding to the given speed and size optimization -/// levels (e.g. -O2 or -Os). +/// levels (e.g. -O2 or -Os). If not null, `targetMachine` is used to +/// initialize passes that provide target-specific information to the LLVM +/// optimizer. `targetMachine` must outlive the returned std::function. 
std::function -makeOptimizingTransformer(unsigned optLevel, unsigned sizeLevel); +makeOptimizingTransformer(unsigned optLevel, unsigned sizeLevel, + llvm::TargetMachine *targetMachine); /// Create a module transformer function for MLIR ExecutionEngine that runs /// LLVM IR passes explicitly specified, plus an optional optimization level, /// Any optimization passes, if present, will be inserted before the pass at -/// position optPassesInsertPos. +/// position optPassesInsertPos. If not null, `targetMachine` is used to +/// initialize passes that provide target-specific information to the LLVM +/// optimizer. `targetMachine` must outlive the returned std::function. std::function makeLLVMPassesTransformer(llvm::ArrayRef llvmPasses, llvm::Optional mbOptLevel, + llvm::TargetMachine *targetMachine, unsigned optPassesInsertPos = 0); } // end namespace mlir diff --git a/third_party/mlir/lib/ExecutionEngine/OptUtils.cpp b/third_party/mlir/lib/ExecutionEngine/OptUtils.cpp index 7831d67c62b..e8c6652f446 100644 --- a/third_party/mlir/lib/ExecutionEngine/OptUtils.cpp +++ b/third_party/mlir/lib/ExecutionEngine/OptUtils.cpp @@ -23,6 +23,7 @@ #include "mlir/ExecutionEngine/OptUtils.h" #include "llvm/ADT/ArrayRef.h" +#include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/IR/LegacyPassManager.h" #include "llvm/IR/LegacyPassNameParser.h" #include "llvm/IR/Module.h" @@ -32,6 +33,7 @@ #include "llvm/Support/CommandLine.h" #include "llvm/Support/Error.h" #include "llvm/Support/StringSaver.h" +#include "llvm/Target/TargetMachine.h" #include "llvm/Transforms/IPO.h" #include "llvm/Transforms/IPO/PassManagerBuilder.h" #include @@ -69,7 +71,8 @@ void mlir::initializeLLVMPasses() { // This behaves similarly to LLVM opt. 
static void populatePassManagers(llvm::legacy::PassManager &modulePM, llvm::legacy::FunctionPassManager &funcPM, - unsigned optLevel, unsigned sizeLevel) { + unsigned optLevel, unsigned sizeLevel, + llvm::TargetMachine *targetMachine) { llvm::PassManagerBuilder builder; builder.OptLevel = optLevel; builder.SizeLevel = sizeLevel; @@ -79,6 +82,15 @@ static void populatePassManagers(llvm::legacy::PassManager &modulePM, builder.SLPVectorize = optLevel > 1 && sizeLevel < 2; builder.DisableUnrollLoops = (optLevel == 0); + if (targetMachine) { + // Add pass to initialize TTI for this specific target. Otherwise, TTI will + // be initialized to NoTTIImpl by default. + modulePM.add(createTargetTransformInfoWrapperPass( + targetMachine->getTargetIRAnalysis())); + funcPM.add(createTargetTransformInfoWrapperPass( + targetMachine->getTargetIRAnalysis())); + } + builder.populateModulePassManager(modulePM); builder.populateFunctionPassManager(funcPM); } @@ -86,11 +98,12 @@ static void populatePassManagers(llvm::legacy::PassManager &modulePM, // Create and return a lambda that uses LLVM pass manager builder to set up // optimizations based on the given level. 
std::function mlir::makeLLVMPassesTransformer( llvm::ArrayRef llvmPasses, - llvm::Optional mbOptLevel, unsigned optPassesInsertPos) { - return [llvmPasses, mbOptLevel, - optPassesInsertPos](llvm::Module *m) -> llvm::Error { + llvm::Optional mbOptLevel, llvm::TargetMachine *targetMachine, + unsigned optPassesInsertPos) { + return [llvmPasses, mbOptLevel, optPassesInsertPos, + targetMachine](llvm::Module *m) -> llvm::Error { llvm::legacy::PassManager modulePM; llvm::legacy::FunctionPassManager funcPM(m); @@ -114,7 +128,8 @@ std::function mlir::makeLLVMPassesTransformer( continue; if (insertOptPasses && optPassesInsertPos == i) { - populatePassManagers(modulePM, funcPM, mbOptLevel.getValue(), 0); + populatePassManagers(modulePM, funcPM, mbOptLevel.getValue(), 0, + targetMachine); insertOptPasses = false; } @@ -127,7 +142,8 @@ std::function mlir::makeLLVMPassesTransformer( } if (insertOptPasses) - populatePassManagers(modulePM, funcPM, mbOptLevel.getValue(), 0); + populatePassManagers(modulePM, funcPM, mbOptLevel.getValue(), 0, + targetMachine); runPasses(modulePM, funcPM, *m); return llvm::Error::success(); diff --git a/third_party/mlir/lib/Support/JitRunner.cpp b/third_party/mlir/lib/Support/JitRunner.cpp index 1c6df7c5be8..56058c27be3 100644 --- a/third_party/mlir/lib/Support/JitRunner.cpp +++ b/third_party/mlir/lib/Support/JitRunner.cpp @@ -308,8 +308,8 @@ int mlir::JitRunnerMain( if (failed(mlirTransformer(m.get()))) return EXIT_FAILURE; - auto transformer = - mlir::makeLLVMPassesTransformer(passes, optLevel, optPosition); + auto transformer = mlir::makeLLVMPassesTransformer( + passes, optLevel, /*targetMachine=*/nullptr, optPosition); auto error = mainFuncType.getValue() == "f32" ? compileAndExecuteSingleFloatReturnFunction( m.get(), mainFuncName.getValue(), transformer)