Add TTI pass initialization to pass managers.

Many LLVM transformations benefit from knowing the targets. This enables optimizations,
especially in a JIT context when the target is (generally) well-known.

Closes #49

COPYBARA_INTEGRATE_REVIEW=https://github.com/tensorflow/mlir/pull/49 from dcaballe:dcaballe/tti ab02f72eb326f660945696e5dadeeb983cf263b3
PiperOrigin-RevId: 261840617
This commit is contained in:
Diego Caballero 2019-08-05 22:13:56 -07:00 committed by TensorFlower Gardener
parent 4623b733f7
commit d18653005c
3 changed files with 37 additions and 14 deletions

View File

@ -31,6 +31,7 @@
namespace llvm { namespace llvm {
class Module; class Module;
class Error; class Error;
class TargetMachine;
} // namespace llvm } // namespace llvm
namespace mlir { namespace mlir {
@ -41,17 +42,23 @@ void initializeLLVMPasses();
/// Create a module transformer function for MLIR ExecutionEngine that runs /// Create a module transformer function for MLIR ExecutionEngine that runs
/// LLVM IR passes corresponding to the given speed and size optimization /// LLVM IR passes corresponding to the given speed and size optimization
/// levels (e.g. -O2 or -Os). /// levels (e.g. -O2 or -Os). If not null, `targetMachine` is used to
/// initialize passes that provide target-specific information to the LLVM
/// optimizer. `targetMachine` must outlive the returned std::function.
std::function<llvm::Error(llvm::Module *)> std::function<llvm::Error(llvm::Module *)>
makeOptimizingTransformer(unsigned optLevel, unsigned sizeLevel); makeOptimizingTransformer(unsigned optLevel, unsigned sizeLevel,
llvm::TargetMachine *targetMachine);
/// Create a module transformer function for MLIR ExecutionEngine that runs /// Create a module transformer function for MLIR ExecutionEngine that runs
/// LLVM IR passes explicitly specified, plus an optional optimization level, /// LLVM IR passes explicitly specified, plus an optional optimization level,
/// Any optimization passes, if present, will be inserted before the pass at /// Any optimization passes, if present, will be inserted before the pass at
/// position optPassesInsertPos. /// position optPassesInsertPos. If not null, `targetMachine` is used to
/// initialize passes that provide target-specific information to the LLVM
/// optimizer. `targetMachine` must outlive the returned std::function.
std::function<llvm::Error(llvm::Module *)> std::function<llvm::Error(llvm::Module *)>
makeLLVMPassesTransformer(llvm::ArrayRef<const llvm::PassInfo *> llvmPasses, makeLLVMPassesTransformer(llvm::ArrayRef<const llvm::PassInfo *> llvmPasses,
llvm::Optional<unsigned> mbOptLevel, llvm::Optional<unsigned> mbOptLevel,
llvm::TargetMachine *targetMachine,
unsigned optPassesInsertPos = 0); unsigned optPassesInsertPos = 0);
} // end namespace mlir } // end namespace mlir

View File

@ -23,6 +23,7 @@
#include "mlir/ExecutionEngine/OptUtils.h" #include "mlir/ExecutionEngine/OptUtils.h"
#include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/ArrayRef.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/IR/LegacyPassManager.h" #include "llvm/IR/LegacyPassManager.h"
#include "llvm/IR/LegacyPassNameParser.h" #include "llvm/IR/LegacyPassNameParser.h"
#include "llvm/IR/Module.h" #include "llvm/IR/Module.h"
@ -32,6 +33,7 @@
#include "llvm/Support/CommandLine.h" #include "llvm/Support/CommandLine.h"
#include "llvm/Support/Error.h" #include "llvm/Support/Error.h"
#include "llvm/Support/StringSaver.h" #include "llvm/Support/StringSaver.h"
#include "llvm/Target/TargetMachine.h"
#include "llvm/Transforms/IPO.h" #include "llvm/Transforms/IPO.h"
#include "llvm/Transforms/IPO/PassManagerBuilder.h" #include "llvm/Transforms/IPO/PassManagerBuilder.h"
#include <climits> #include <climits>
@ -69,7 +71,8 @@ void mlir::initializeLLVMPasses() {
// This behaves similarly to LLVM opt. // This behaves similarly to LLVM opt.
static void populatePassManagers(llvm::legacy::PassManager &modulePM, static void populatePassManagers(llvm::legacy::PassManager &modulePM,
llvm::legacy::FunctionPassManager &funcPM, llvm::legacy::FunctionPassManager &funcPM,
unsigned optLevel, unsigned sizeLevel) { unsigned optLevel, unsigned sizeLevel,
llvm::TargetMachine *targetMachine) {
llvm::PassManagerBuilder builder; llvm::PassManagerBuilder builder;
builder.OptLevel = optLevel; builder.OptLevel = optLevel;
builder.SizeLevel = sizeLevel; builder.SizeLevel = sizeLevel;
@ -79,6 +82,15 @@ static void populatePassManagers(llvm::legacy::PassManager &modulePM,
builder.SLPVectorize = optLevel > 1 && sizeLevel < 2; builder.SLPVectorize = optLevel > 1 && sizeLevel < 2;
builder.DisableUnrollLoops = (optLevel == 0); builder.DisableUnrollLoops = (optLevel == 0);
if (targetMachine) {
// Add pass to initialize TTI for this specific target. Otherwise, TTI will
// be initialized to NoTTIImpl by default.
modulePM.add(createTargetTransformInfoWrapperPass(
targetMachine->getTargetIRAnalysis()));
funcPM.add(createTargetTransformInfoWrapperPass(
targetMachine->getTargetIRAnalysis()));
}
builder.populateModulePassManager(modulePM); builder.populateModulePassManager(modulePM);
builder.populateFunctionPassManager(funcPM); builder.populateFunctionPassManager(funcPM);
} }
@ -86,11 +98,12 @@ static void populatePassManagers(llvm::legacy::PassManager &modulePM,
// Create and return a lambda that uses LLVM pass manager builder to set up // Create and return a lambda that uses LLVM pass manager builder to set up
// optimizations based on the given level. // optimizations based on the given level.
std::function<llvm::Error(llvm::Module *)> std::function<llvm::Error(llvm::Module *)>
mlir::makeOptimizingTransformer(unsigned optLevel, unsigned sizeLevel) { mlir::makeOptimizingTransformer(unsigned optLevel, unsigned sizeLevel,
return [optLevel, sizeLevel](llvm::Module *m) -> llvm::Error { llvm::TargetMachine *targetMachine) {
return [optLevel, sizeLevel, targetMachine](llvm::Module *m) -> llvm::Error {
llvm::legacy::PassManager modulePM; llvm::legacy::PassManager modulePM;
llvm::legacy::FunctionPassManager funcPM(m); llvm::legacy::FunctionPassManager funcPM(m);
populatePassManagers(modulePM, funcPM, optLevel, sizeLevel); populatePassManagers(modulePM, funcPM, optLevel, sizeLevel, targetMachine);
runPasses(modulePM, funcPM, *m); runPasses(modulePM, funcPM, *m);
return llvm::Error::success(); return llvm::Error::success();
@ -101,9 +114,10 @@ mlir::makeOptimizingTransformer(unsigned optLevel, unsigned sizeLevel) {
// optional optimization level to pre-populate the pass manager. // optional optimization level to pre-populate the pass manager.
std::function<llvm::Error(llvm::Module *)> mlir::makeLLVMPassesTransformer( std::function<llvm::Error(llvm::Module *)> mlir::makeLLVMPassesTransformer(
llvm::ArrayRef<const llvm::PassInfo *> llvmPasses, llvm::ArrayRef<const llvm::PassInfo *> llvmPasses,
llvm::Optional<unsigned> mbOptLevel, unsigned optPassesInsertPos) { llvm::Optional<unsigned> mbOptLevel, llvm::TargetMachine *targetMachine,
return [llvmPasses, mbOptLevel, unsigned optPassesInsertPos) {
optPassesInsertPos](llvm::Module *m) -> llvm::Error { return [llvmPasses, mbOptLevel, optPassesInsertPos,
targetMachine](llvm::Module *m) -> llvm::Error {
llvm::legacy::PassManager modulePM; llvm::legacy::PassManager modulePM;
llvm::legacy::FunctionPassManager funcPM(m); llvm::legacy::FunctionPassManager funcPM(m);
@ -114,7 +128,8 @@ std::function<llvm::Error(llvm::Module *)> mlir::makeLLVMPassesTransformer(
continue; continue;
if (insertOptPasses && optPassesInsertPos == i) { if (insertOptPasses && optPassesInsertPos == i) {
populatePassManagers(modulePM, funcPM, mbOptLevel.getValue(), 0); populatePassManagers(modulePM, funcPM, mbOptLevel.getValue(), 0,
targetMachine);
insertOptPasses = false; insertOptPasses = false;
} }
@ -127,7 +142,8 @@ std::function<llvm::Error(llvm::Module *)> mlir::makeLLVMPassesTransformer(
} }
if (insertOptPasses) if (insertOptPasses)
populatePassManagers(modulePM, funcPM, mbOptLevel.getValue(), 0); populatePassManagers(modulePM, funcPM, mbOptLevel.getValue(), 0,
targetMachine);
runPasses(modulePM, funcPM, *m); runPasses(modulePM, funcPM, *m);
return llvm::Error::success(); return llvm::Error::success();

View File

@ -308,8 +308,8 @@ int mlir::JitRunnerMain(
if (failed(mlirTransformer(m.get()))) if (failed(mlirTransformer(m.get())))
return EXIT_FAILURE; return EXIT_FAILURE;
auto transformer = auto transformer = mlir::makeLLVMPassesTransformer(
mlir::makeLLVMPassesTransformer(passes, optLevel, optPosition); passes, optLevel, /*targetMachine=*/nullptr, optPosition);
auto error = mainFuncType.getValue() == "f32" auto error = mainFuncType.getValue() == "f32"
? compileAndExecuteSingleFloatReturnFunction( ? compileAndExecuteSingleFloatReturnFunction(
m.get(), mainFuncName.getValue(), transformer) m.get(), mainFuncName.getValue(), transformer)