Add TTI pass initialization to pass managers.

Many LLVM transformations benefit from knowing the target. This enables target-specific
optimizations, especially in a JIT context where the target is (generally) well-known.

Closes #49

COPYBARA_INTEGRATE_REVIEW=https://github.com/tensorflow/mlir/pull/49 from dcaballe:dcaballe/tti ab02f72eb326f660945696e5dadeeb983cf263b3
PiperOrigin-RevId: 261840617
This commit is contained in:
Diego Caballero 2019-08-05 22:13:56 -07:00 committed by TensorFlower Gardener
parent 4623b733f7
commit d18653005c
3 changed files with 37 additions and 14 deletions

View File

@ -31,6 +31,7 @@
namespace llvm {
class Module;
class Error;
class TargetMachine;
} // namespace llvm
namespace mlir {
@ -41,17 +42,23 @@ void initializeLLVMPasses();
/// Create a module transformer function for MLIR ExecutionEngine that runs
/// LLVM IR passes corresponding to the given speed and size optimization
/// levels (e.g. -O2 or -Os).
/// levels (e.g. -O2 or -Os). If not null, `targetMachine` is used to
/// initialize passes that provide target-specific information to the LLVM
/// optimizer. `targetMachine` must outlive the returned std::function.
std::function<llvm::Error(llvm::Module *)>
makeOptimizingTransformer(unsigned optLevel, unsigned sizeLevel);
makeOptimizingTransformer(unsigned optLevel, unsigned sizeLevel,
llvm::TargetMachine *targetMachine);
/// Create a module transformer function for MLIR ExecutionEngine that runs
/// LLVM IR passes explicitly specified, plus an optional optimization level.
/// Any optimization passes, if present, will be inserted before the pass at
/// position optPassesInsertPos.
/// position optPassesInsertPos. If not null, `targetMachine` is used to
/// initialize passes that provide target-specific information to the LLVM
/// optimizer. `targetMachine` must outlive the returned std::function.
std::function<llvm::Error(llvm::Module *)>
makeLLVMPassesTransformer(llvm::ArrayRef<const llvm::PassInfo *> llvmPasses,
llvm::Optional<unsigned> mbOptLevel,
llvm::TargetMachine *targetMachine,
unsigned optPassesInsertPos = 0);
} // end namespace mlir

View File

@ -23,6 +23,7 @@
#include "mlir/ExecutionEngine/OptUtils.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/IR/LegacyPassManager.h"
#include "llvm/IR/LegacyPassNameParser.h"
#include "llvm/IR/Module.h"
@ -32,6 +33,7 @@
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Error.h"
#include "llvm/Support/StringSaver.h"
#include "llvm/Target/TargetMachine.h"
#include "llvm/Transforms/IPO.h"
#include "llvm/Transforms/IPO/PassManagerBuilder.h"
#include <climits>
@ -69,7 +71,8 @@ void mlir::initializeLLVMPasses() {
// This behaves similarly to LLVM opt.
static void populatePassManagers(llvm::legacy::PassManager &modulePM,
llvm::legacy::FunctionPassManager &funcPM,
unsigned optLevel, unsigned sizeLevel) {
unsigned optLevel, unsigned sizeLevel,
llvm::TargetMachine *targetMachine) {
llvm::PassManagerBuilder builder;
builder.OptLevel = optLevel;
builder.SizeLevel = sizeLevel;
@ -79,6 +82,15 @@ static void populatePassManagers(llvm::legacy::PassManager &modulePM,
builder.SLPVectorize = optLevel > 1 && sizeLevel < 2;
builder.DisableUnrollLoops = (optLevel == 0);
if (targetMachine) {
// Add pass to initialize TTI for this specific target. Otherwise, TTI will
// be initialized to NoTTIImpl by default.
modulePM.add(createTargetTransformInfoWrapperPass(
targetMachine->getTargetIRAnalysis()));
funcPM.add(createTargetTransformInfoWrapperPass(
targetMachine->getTargetIRAnalysis()));
}
builder.populateModulePassManager(modulePM);
builder.populateFunctionPassManager(funcPM);
}
@ -86,11 +98,12 @@ static void populatePassManagers(llvm::legacy::PassManager &modulePM,
// Create and return a lambda that uses LLVM pass manager builder to set up
// optimizations based on the given level.
std::function<llvm::Error(llvm::Module *)>
mlir::makeOptimizingTransformer(unsigned optLevel, unsigned sizeLevel) {
return [optLevel, sizeLevel](llvm::Module *m) -> llvm::Error {
mlir::makeOptimizingTransformer(unsigned optLevel, unsigned sizeLevel,
llvm::TargetMachine *targetMachine) {
return [optLevel, sizeLevel, targetMachine](llvm::Module *m) -> llvm::Error {
llvm::legacy::PassManager modulePM;
llvm::legacy::FunctionPassManager funcPM(m);
populatePassManagers(modulePM, funcPM, optLevel, sizeLevel);
populatePassManagers(modulePM, funcPM, optLevel, sizeLevel, targetMachine);
runPasses(modulePM, funcPM, *m);
return llvm::Error::success();
@ -101,9 +114,10 @@ mlir::makeOptimizingTransformer(unsigned optLevel, unsigned sizeLevel) {
// optional optimization level to pre-populate the pass manager.
std::function<llvm::Error(llvm::Module *)> mlir::makeLLVMPassesTransformer(
llvm::ArrayRef<const llvm::PassInfo *> llvmPasses,
llvm::Optional<unsigned> mbOptLevel, unsigned optPassesInsertPos) {
return [llvmPasses, mbOptLevel,
optPassesInsertPos](llvm::Module *m) -> llvm::Error {
llvm::Optional<unsigned> mbOptLevel, llvm::TargetMachine *targetMachine,
unsigned optPassesInsertPos) {
return [llvmPasses, mbOptLevel, optPassesInsertPos,
targetMachine](llvm::Module *m) -> llvm::Error {
llvm::legacy::PassManager modulePM;
llvm::legacy::FunctionPassManager funcPM(m);
@ -114,7 +128,8 @@ std::function<llvm::Error(llvm::Module *)> mlir::makeLLVMPassesTransformer(
continue;
if (insertOptPasses && optPassesInsertPos == i) {
populatePassManagers(modulePM, funcPM, mbOptLevel.getValue(), 0);
populatePassManagers(modulePM, funcPM, mbOptLevel.getValue(), 0,
targetMachine);
insertOptPasses = false;
}
@ -127,7 +142,8 @@ std::function<llvm::Error(llvm::Module *)> mlir::makeLLVMPassesTransformer(
}
if (insertOptPasses)
populatePassManagers(modulePM, funcPM, mbOptLevel.getValue(), 0);
populatePassManagers(modulePM, funcPM, mbOptLevel.getValue(), 0,
targetMachine);
runPasses(modulePM, funcPM, *m);
return llvm::Error::success();

View File

@ -308,8 +308,8 @@ int mlir::JitRunnerMain(
if (failed(mlirTransformer(m.get())))
return EXIT_FAILURE;
auto transformer =
mlir::makeLLVMPassesTransformer(passes, optLevel, optPosition);
auto transformer = mlir::makeLLVMPassesTransformer(
passes, optLevel, /*targetMachine=*/nullptr, optPosition);
auto error = mainFuncType.getValue() == "f32"
? compileAndExecuteSingleFloatReturnFunction(
m.get(), mainFuncName.getValue(), transformer)