Generalize the pass hierarchy by adding a general OpPass<PassT, OpT>.

This pass class generalizes the current functionality between FunctionPass and ModulePass, and allows for operating on any operation type. The pass manager currently only supports OpPasses operating on FuncOp and ModuleOp, but this restriction will be relaxed in follow-up changes. A utility class OpPassBase<OpT> allows for generically referring to operation specific passes: e.g. FunctionPassBase == OpPassBase<FuncOp>.

PiperOrigin-RevId: 266442239
This commit is contained in:
River Riddle 2019-08-30 13:16:13 -07:00 committed by TensorFlower Gardener
parent 93b86cc5ee
commit 222cdccfa6
19 changed files with 226 additions and 214 deletions

View File

@ -21,8 +21,12 @@ limitations under the License.
#include "llvm/ADT/ArrayRef.h"
namespace mlir {
class FunctionPassBase;
class ModulePassBase;
class FuncOp;
class ModuleOp;
template <typename T>
class OpPassBase;
using FunctionPassBase = OpPassBase<FuncOp>;
using ModulePassBase = OpPassBase<ModuleOp>;
namespace TFL {

View File

@ -19,7 +19,10 @@ limitations under the License.
#include <memory>
namespace mlir {
class FunctionPassBase;
class FuncOp;
template <typename T>
class OpPassBase;
using FunctionPassBase = OpPassBase<FuncOp>;
namespace xla_hlo {

View File

@ -27,7 +27,9 @@
namespace mlir {
class FunctionPassBase;
class FuncOp;
template <typename T> class OpPassBase;
using FunctionPassBase = OpPassBase<FuncOp>;
/// Creates a pass to check memref accesses in a Function.
FunctionPassBase *createMemRefBoundCheckPass();

View File

@ -23,9 +23,10 @@
namespace mlir {
class FuncOp;
class FunctionPassBase;
struct LogicalResult;
class MLIRContext;
template <typename T> class OpPassBase;
using FunctionPassBase = OpPassBase<FuncOp>;
class RewritePattern;
// Owning list of rewriting patterns.

View File

@ -27,7 +27,7 @@ namespace mlir {
class FuncOp;
class Location;
class ModulePassBase;
class ModuleOp;
class OpBuilder;
class Value;
@ -35,6 +35,9 @@ namespace LLVM {
class LLVMDialect;
}
template <typename T> class OpPassBase;
using ModulePassBase = OpPassBase<ModuleOp>;
using OwnedCubin = std::unique_ptr<std::vector<char>>;
using CubinGenerator = std::function<OwnedCubin(const std::string &, FuncOp &)>;

View File

@ -21,9 +21,12 @@
namespace mlir {
class LLVMTypeConverter;
class ModulePassBase;
class OwningRewritePatternList;
class ModuleOp;
template <typename OpT> class OpPassBase;
using ModulePassBase = OpPassBase<ModuleOp>;
/// Collect a set of patterns to convert from the GPU dialect to NVVM.
void populateGpuToNVVMConversionPatterns(LLVMTypeConverter &converter,
OwningRewritePatternList &patterns);

View File

@ -20,7 +20,9 @@
#include <memory>
namespace mlir {
class FunctionPassBase;
class FuncOp;
template <typename T> class OpPassBase;
using FunctionPassBase = OpPassBase<FuncOp>;
/// Create a pass that converts loop nests into GPU kernels. It considers
/// top-level affine.for and linalg.for operations as roots of loop nests and

View File

@ -33,7 +33,8 @@ class LLVMTypeConverter;
struct LogicalResult;
class MLIRContext;
class ModuleOp;
class ModulePassBase;
template <typename T> class OpPassBase;
using ModulePassBase = OpPassBase<ModuleOp>;
class RewritePattern;
class Type;

View File

@ -19,9 +19,12 @@
namespace mlir {
class LLVMTypeConverter;
class ModulePassBase;
class ModuleOp;
class OwningRewritePatternList;
template <typename T> class OpPassBase;
using ModulePassBase = OpPassBase<ModuleOp>;
/// Collect a set of patterns to convert from the Vector dialect to LLVM.
void populateVectorToLLVMConversionPatterns(LLVMTypeConverter &converter,
OwningRewritePatternList &patterns);

View File

@ -23,7 +23,9 @@
#define MLIR_DIALECT_FXPMATHOPS_PASSES_H
namespace mlir {
class FunctionPassBase;
class FuncOp;
template <typename T> class OpPassBase;
using FunctionPassBase = OpPassBase<FuncOp>;
namespace fxpmath {

View File

@ -26,7 +26,9 @@
namespace mlir {
class ModulePassBase;
class ModuleOp;
template <typename T> class OpPassBase;
using ModulePassBase = OpPassBase<ModuleOp>;
std::unique_ptr<ModulePassBase> createGpuKernelOutliningPass();

View File

@ -26,8 +26,11 @@
#include "llvm/ADT/ArrayRef.h"
namespace mlir {
class FunctionPassBase;
class ModulePassBase;
class FuncOp;
class ModuleOp;
template <typename T> class OpPassBase;
using FunctionPassBase = OpPassBase<FuncOp>;
using ModulePassBase = OpPassBase<ModuleOp>;
namespace linalg {
std::unique_ptr<FunctionPassBase>

View File

@ -28,7 +28,9 @@
#include <memory>
namespace mlir {
class FunctionPassBase;
class FuncOp;
template <typename T> class OpPassBase;
using FunctionPassBase = OpPassBase<FuncOp>;
namespace quant {

View File

@ -25,16 +25,36 @@
#include "llvm/ADT/PointerIntPair.h"
namespace mlir {
namespace detail {
class FunctionPassExecutor;
class ModulePassExecutor;
/// The state for a single execution of a pass. This provides a unified
/// interface for accessing and initializing necessary state for pass execution.
struct PassExecutionState {
PassExecutionState(Operation *ir, AnalysisManager analysisManager)
: irAndPassFailed(ir, false), analysisManager(analysisManager) {}
/// The current operation being transformed and a bool for if the pass
/// signaled a failure.
llvm::PointerIntPair<Operation *, 1, bool> irAndPassFailed;
/// The analysis manager for the operation.
AnalysisManager analysisManager;
/// The set of preserved analyses for the current execution.
detail::PreservedAnalyses preservedAnalyses;
};
} // namespace detail
/// The abstract base pass class. This class contains information describing the
/// derived pass object, e.g its kind and abstract PassInfo.
class Pass {
public:
enum class Kind { FunctionPass, ModulePass };
virtual ~Pass() = default;
/// Returns the unique identifier that corresponds to this pass.
const PassID *getPassID() const { return passIDAndKind.getPointer(); }
const PassID *getPassID() const { return passID; }
/// Returns the pass info for the specified pass class or null if unknown.
static const PassInfo *lookupPassInfo(const PassID *passID);
@ -45,137 +65,61 @@ public:
/// Returns the pass info for this pass.
const PassInfo *lookupPassInfo() const { return lookupPassInfo(getPassID()); }
/// Return the kind of this pass.
Kind getKind() const { return passIDAndKind.getInt(); }
/// Returns the derived pass name.
virtual StringRef getName() = 0;
/// Returns the name of the operation that this pass operates on.
StringRef getOpName() const { return opName; }
protected:
Pass(const PassID *passID, Kind kind) : passIDAndKind(passID, kind) {}
Pass(const PassID *passID, StringRef opName)
: passID(passID), opName(opName) {}
/// Returns the current pass state.
detail::PassExecutionState &getPassState() {
assert(passState && "pass state was never initialized");
return *passState;
}
/// Return the MLIR context for the current function being transformed.
MLIRContext &getContext() { return *getOperation()->getContext(); }
/// The polymorphic API that runs the pass over the currently held operation.
virtual void runOnOperation() = 0;
/// A clone method to create a copy of this pass.
virtual std::unique_ptr<Pass> clone() const = 0;
/// Return the current operation being transformed.
Operation *getOperation() {
return getPassState().irAndPassFailed.getPointer();
}
/// Returns the current analysis manager.
AnalysisManager getAnalysisManager() {
return getPassState().analysisManager;
}
private:
/// Forwarding function to execute this pass on the given operation.
LLVM_NODISCARD
LogicalResult run(Operation *op, AnalysisManager am);
/// Out of line virtual method to ensure vtables and metadata are emitted to a
/// single .o file.
virtual void anchor();
/// Represents a unique identifier for the pass and its kind.
llvm::PointerIntPair<const PassID *, 1, Kind> passIDAndKind;
};
/// Represents a unique identifier for the pass.
const PassID *passID;
namespace detail {
class FunctionPassExecutor;
class ModulePassExecutor;
/// The state for a single execution of a pass. This provides a unified
/// interface for accessing and initializing necessary state for pass execution.
template <typename IRUnitT> struct PassExecutionState {
PassExecutionState(IRUnitT ir, AnalysisManager analysisManager)
: irAndPassFailed(ir, false), analysisManager(analysisManager) {}
/// The current IR unit being transformed and a bool for if the pass signaled
/// a failure.
llvm::PointerIntPair<IRUnitT, 1, bool> irAndPassFailed;
/// The analysis manager for the IR unit.
AnalysisManager analysisManager;
/// The set of preserved analyses for the current execution.
detail::PreservedAnalyses preservedAnalyses;
};
} // namespace detail
/// Pass to transform a specific function within a module. Derived passes should
/// not inherit from this class directly, and instead should use the CRTP
/// FunctionPass class.
class FunctionPassBase : public Pass {
using PassStateT = detail::PassExecutionState<FuncOp>;
public:
static bool classof(const Pass *pass) {
return pass->getKind() == Kind::FunctionPass;
}
protected:
explicit FunctionPassBase(const PassID *id) : Pass(id, Kind::FunctionPass) {}
/// The polymorphic API that runs the pass over the currently held function.
virtual void runOnFunction() = 0;
/// A clone method to create a copy of this pass.
virtual std::unique_ptr<FunctionPassBase> clone() const = 0;
/// Return the current function being transformed.
FuncOp getFunction() { return getPassState().irAndPassFailed.getPointer(); }
/// Return the MLIR context for the current function being transformed.
MLIRContext &getContext() { return *getFunction().getContext(); }
/// Returns the current pass state.
PassStateT &getPassState() {
assert(passState && "pass state was never initialized");
return *passState;
}
/// Returns the current analysis manager.
AnalysisManager getAnalysisManager() {
return getPassState().analysisManager;
}
private:
/// Forwarding function to execute this pass.
LLVM_NODISCARD
LogicalResult run(FuncOp fn, AnalysisManager am);
/// The name of the operation that this pass operates on.
StringRef opName;
/// The current execution state for the pass.
llvm::Optional<PassStateT> passState;
llvm::Optional<detail::PassExecutionState> passState;
/// Allow access to 'run'.
/// Allow access to 'clone' and 'run'.
friend detail::FunctionPassExecutor;
};
/// Pass to transform a module. Derived passes should not inherit from this
/// class directly, and instead should use the CRTP ModulePass class.
class ModulePassBase : public Pass {
using PassStateT = detail::PassExecutionState<ModuleOp>;
public:
static bool classof(const Pass *pass) {
return pass->getKind() == Kind::ModulePass;
}
protected:
explicit ModulePassBase(const PassID *id) : Pass(id, Kind::ModulePass) {}
/// The polymorphic API that runs the pass over the currently held module.
virtual void runOnModule() = 0;
/// Return the current module being transformed.
ModuleOp getModule() { return getPassState().irAndPassFailed.getPointer(); }
/// Return the MLIR context for the current module being transformed.
MLIRContext &getContext() { return *getModule().getContext(); }
/// Returns the current pass state.
PassStateT &getPassState() {
assert(passState && "pass state was never initialized");
return *passState;
}
/// Returns the current analysis manager.
AnalysisManager getAnalysisManager() {
return getPassState().analysisManager;
}
private:
/// Forwarding function to execute this pass.
LLVM_NODISCARD
LogicalResult run(ModuleOp module, AnalysisManager am);
/// The current execution state for the pass.
llvm::Optional<PassStateT> passState;
/// Allow access to 'run'.
friend detail::ModulePassExecutor;
};
@ -185,7 +129,7 @@ private:
namespace detail {
/// The opaque CRTP model of a pass. This class provides utilities for derived
/// pass execution and handles all of the necessary polymorphic API.
template <typename IRUnitT, typename PassT, typename BasePassT>
template <typename PassT, typename BasePassT>
class PassModel : public BasePassT {
public:
/// Support isa/dyn_cast functionality for the derived pass class.
@ -194,7 +138,7 @@ public:
}
protected:
PassModel() : BasePassT(PassID::getID<PassT>()) {}
PassModel(StringRef opName) : BasePassT(PassID::getID<PassT>(), opName) {}
/// Signal that some invariant was broken when running. The IR is allowed to
/// be in an invalid state.
@ -234,9 +178,75 @@ protected:
name.consume_front("(anonymous namespace)::");
return name;
}
/// A clone method to create a copy of this pass.
std::unique_ptr<Pass> clone() const override {
return std::make_unique<PassT>(*static_cast<const PassT *>(this));
}
/// Returns the analysis for the parent operation if it exists.
template <typename AnalysisT>
llvm::Optional<std::reference_wrapper<AnalysisT>>
getCachedParentAnalysis(Operation *parent) {
return this->getAnalysisManager()
.template getCachedParentAnalysis<AnalysisT>(parent);
}
template <typename AnalysisT>
llvm::Optional<std::reference_wrapper<AnalysisT>> getCachedParentAnalysis() {
return this->getAnalysisManager()
.template getCachedParentAnalysis<AnalysisT>(
this->getOperation()->getParentOp());
}
/// Returns the analysis for the given child operation if it exists.
template <typename AnalysisT>
llvm::Optional<std::reference_wrapper<AnalysisT>>
getCachedChildAnalysis(Operation *child) {
return this->getAnalysisManager()
.template getCachedChildAnalysis<AnalysisT>(child);
}
/// Returns the analysis for the given child operation, or creates it if it
/// doesn't exist.
template <typename AnalysisT> AnalysisT &getChildAnalysis(Operation *child) {
return this->getAnalysisManager().template getChildAnalysis<AnalysisT>(
child);
}
};
} // end namespace detail
/// Utility base class for OpPass below to denote an opaque pass operating on a
/// specific operation type.
template <typename OpT> class OpPassBase : public Pass {
public:
using Pass::Pass;
/// Support isa/dyn_cast functionality.
static bool classof(const Pass *pass) {
return pass->getOpName() == OpT::getOperationName();
}
};
/// Pass to transform an operation of a specific type.
///
/// Operation passes must not:
/// - read or modify any other operations within the parent region, as
/// other threads may be manipulating them concurrently.
/// - modify any state within the parent operation, this includes adding
/// additional operations.
///
/// Derived function passes are expected to provide the following:
/// - A 'void runOnOperation()' method.
template <typename OpT, typename PassT>
class OpPass : public detail::PassModel<PassT, OpPassBase<OpT>> {
protected:
OpPass()
: detail::PassModel<PassT, OpPassBase<OpT>>(OpT::getOperationName()) {}
/// Return the current operation being transformed.
OpT getOperation() { return cast<OpT>(Pass::getOperation()); }
};
/// A model for providing function pass specific utilities.
///
/// Function passes must not:
@ -247,41 +257,39 @@ protected:
///
/// Derived function passes are expected to provide the following:
/// - A 'void runOnFunction()' method.
template <typename T>
struct FunctionPass : public detail::PassModel<FuncOp, T, FunctionPassBase> {
/// Returns the analysis for the parent module if it exists.
template <typename AnalysisT>
llvm::Optional<std::reference_wrapper<AnalysisT>> getCachedModuleAnalysis() {
return this->getAnalysisManager()
.template getCachedParentAnalysis<AnalysisT>(
this->getFunction().getParentOp());
template <typename T> struct FunctionPass : public OpPass<FuncOp, T> {
/// The polymorphic API that runs the pass over the currently held function.
virtual void runOnFunction() = 0;
/// The polymorphic API that runs the pass over the currently held operation.
void runOnOperation() final {
if (!getFunction().isExternal())
runOnFunction();
}
/// A clone method to create a copy of this pass.
std::unique_ptr<FunctionPassBase> clone() const override {
return std::make_unique<T>(*static_cast<const T *>(this));
}
/// Return the current module being transformed.
FuncOp getFunction() { return this->getOperation(); }
};
/// A model for providing module pass specific utilities.
///
/// Derived module passes are expected to provide the following:
/// - A 'void runOnModule()' method.
template <typename T>
struct ModulePass : public detail::PassModel<ModuleOp, T, ModulePassBase> {
/// Returns the analysis for a child function.
template <typename AnalysisT> AnalysisT &getFunctionAnalysis(FuncOp f) {
return this->getAnalysisManager().template getChildAnalysis<AnalysisT>(f);
}
template <typename T> struct ModulePass : public OpPass<ModuleOp, T> {
/// The polymorphic API that runs the pass over the currently held module.
virtual void runOnModule() = 0;
/// Returns an existing analysis for a child function if it exists.
template <typename AnalysisT>
llvm::Optional<std::reference_wrapper<AnalysisT>>
getCachedFunctionAnalysis(FuncOp f) {
return this->getAnalysisManager()
.template getCachedChildAnalysis<AnalysisT>(f);
}
/// The polymorphic API that runs the pass over the currently held operation.
void runOnOperation() final { runOnModule(); }
/// Return the current module being transformed.
ModuleOp getModule() { return this->getOperation(); }
};
/// Using directives defining legacy base classes.
// TODO(riverriddle) These should be removed in favor of OpPassBase<T>.
using FunctionPassBase = OpPassBase<FuncOp>;
using ModulePassBase = OpPassBase<ModuleOp>;
} // end namespace mlir
#endif // MLIR_PASS_PASS_H

View File

@ -26,9 +26,12 @@ class Any;
} // end namespace llvm
namespace mlir {
class FunctionPassBase;
class FuncOp;
class ModuleOp;
class ModulePassBase;
template <typename OpT> class OpPassBase;
using FunctionPassBase = OpPassBase<FuncOp>;
using ModulePassBase = OpPassBase<ModuleOp>;
class Pass;
class PassInstrumentation;
class PassInstrumentor;

View File

@ -30,8 +30,11 @@
namespace mlir {
class AffineForOp;
class FunctionPassBase;
class ModulePassBase;
class FuncOp;
class ModuleOp;
template <typename T> class OpPassBase;
using FunctionPassBase = OpPassBase<FuncOp>;
using ModulePassBase = OpPassBase<ModuleOp>;
/// Creates a constant folding pass. Note that this pass solely provides simple
/// top-down constant folding functionality; it is intended to be used for

View File

@ -27,7 +27,9 @@
#include "llvm/Support/raw_ostream.h"
namespace mlir {
class FunctionPassBase;
class FuncOp;
template <typename T> class OpPassBase;
using FunctionPassBase = OpPassBase<FuncOp>;
class Region;
/// Displays the CFG in a window. This is for use from the debugger and

View File

@ -42,17 +42,16 @@ using namespace mlir::detail;
void Pass::anchor() {}
/// Forwarding function to execute this pass.
LogicalResult FunctionPassBase::run(FuncOp fn, AnalysisManager am) {
// Initialize the pass state.
passState.emplace(fn, am);
LogicalResult Pass::run(Operation *op, AnalysisManager am) {
passState.emplace(op, am);
// Instrument before the pass has run.
auto pi = am.getPassInstrumentor();
if (pi)
pi->runBeforePass(this, fn);
pi->runBeforePass(this, op);
// Invoke the virtual runOnFunction function.
runOnFunction();
// Invoke the virtual runOnOperation method.
runOnOperation();
// Invalidate any non preserved analyses.
am.invalidate(passState->preservedAnalyses);
@ -61,38 +60,9 @@ LogicalResult FunctionPassBase::run(FuncOp fn, AnalysisManager am) {
bool passFailed = passState->irAndPassFailed.getInt();
if (pi) {
if (passFailed)
pi->runAfterPassFailed(this, fn);
pi->runAfterPassFailed(this, op);
else
pi->runAfterPass(this, fn);
}
// Return if the pass signaled a failure.
return failure(passFailed);
}
/// Forwarding function to execute this pass.
LogicalResult ModulePassBase::run(ModuleOp module, AnalysisManager am) {
// Initialize the pass state.
passState.emplace(module, am);
// Instrument before the pass has run.
auto pi = am.getPassInstrumentor();
if (pi)
pi->runBeforePass(this, module);
// Invoke the virtual runOnModule function.
runOnModule();
// Invalidate any non preserved analyses.
am.invalidate(passState->preservedAnalyses);
// Instrument after the pass has run.
bool passFailed = passState->irAndPassFailed.getInt();
if (pi) {
if (passFailed)
pi->runAfterPassFailed(this, module);
else
pi->runAfterPass(this, module);
pi->runAfterPass(this, op);
}
// Return if the pass signaled a failure.
@ -106,7 +76,7 @@ LogicalResult ModulePassBase::run(ModuleOp module, AnalysisManager am) {
FunctionPassExecutor::FunctionPassExecutor(const FunctionPassExecutor &rhs)
: PassExecutor(Kind::FunctionExecutor) {
for (auto &pass : rhs.passes)
addPass(pass->clone());
addPass(cast<FunctionPassBase>(pass->clone()));
}
/// Run all of the passes in this manager over the current function.
@ -265,14 +235,9 @@ void PassManager::disableMultithreading(bool disable) {
/// Add an opaque pass pointer to the current manager. This takes ownership
/// over the provided pass pointer.
void PassManager::addPass(std::unique_ptr<Pass> pass) {
switch (pass->getKind()) {
case Pass::Kind::FunctionPass:
addPass(cast<FunctionPassBase>(std::move(pass)));
break;
case Pass::Kind::ModulePass:
addPass(cast<ModulePassBase>(std::move(pass)));
break;
}
if (isa<FunctionPassBase>(pass.get()))
return addPass(cast<FunctionPassBase>(std::move(pass)));
addPass(cast<ModulePassBase>(std::move(pass)));
}
/// Add a module pass to the current manager. This takes ownership over the

View File

@ -66,7 +66,7 @@ public:
/// Add a pass to the current executor. This takes ownership over the provided
/// pass pointer.
void addPass(std::unique_ptr<FunctionPassBase> pass) {
void addPass(std::unique_ptr<Pass> pass) {
passes.push_back(std::move(pass));
}
@ -78,7 +78,7 @@ public:
}
private:
std::vector<std::unique_ptr<FunctionPassBase>> passes;
std::vector<std::unique_ptr<Pass>> passes;
};
/// A pass executor that contains a list of passes over a module unit.