diff --git a/tensorflow/compiler/mlir/lite/transforms/passes.h b/tensorflow/compiler/mlir/lite/transforms/passes.h index fb01ba0e9c8..e8c40537ee4 100644 --- a/tensorflow/compiler/mlir/lite/transforms/passes.h +++ b/tensorflow/compiler/mlir/lite/transforms/passes.h @@ -21,8 +21,12 @@ limitations under the License. #include "llvm/ADT/ArrayRef.h" namespace mlir { -class FunctionPassBase; -class ModulePassBase; +class FuncOp; +class ModuleOp; +template +class OpPassBase; +using FunctionPassBase = OpPassBase; +using ModulePassBase = OpPassBase; namespace TFL { diff --git a/tensorflow/compiler/mlir/xla/transforms/passes.h b/tensorflow/compiler/mlir/xla/transforms/passes.h index 3eb97dd6a0f..6f1c278862c 100644 --- a/tensorflow/compiler/mlir/xla/transforms/passes.h +++ b/tensorflow/compiler/mlir/xla/transforms/passes.h @@ -19,7 +19,10 @@ limitations under the License. #include namespace mlir { -class FunctionPassBase; +class FuncOp; +template +class OpPassBase; +using FunctionPassBase = OpPassBase; namespace xla_hlo { diff --git a/third_party/mlir/include/mlir/Analysis/Passes.h b/third_party/mlir/include/mlir/Analysis/Passes.h index 9eafcd35576..8c947e6c222 100644 --- a/third_party/mlir/include/mlir/Analysis/Passes.h +++ b/third_party/mlir/include/mlir/Analysis/Passes.h @@ -27,7 +27,9 @@ namespace mlir { -class FunctionPassBase; +class FuncOp; +template class OpPassBase; +using FunctionPassBase = OpPassBase; /// Creates a pass to check memref accesses in a Function. FunctionPassBase *createMemRefBoundCheckPass(); diff --git a/third_party/mlir/include/mlir/Conversion/ControlFlowToCFG/ConvertControlFlowToCFG.h b/third_party/mlir/include/mlir/Conversion/ControlFlowToCFG/ConvertControlFlowToCFG.h index 78e4356607f..e6bf621cd7c 100644 --- a/third_party/mlir/include/mlir/Conversion/ControlFlowToCFG/ConvertControlFlowToCFG.h +++ b/third_party/mlir/include/mlir/Conversion/ControlFlowToCFG/ConvertControlFlowToCFG.h @@ -23,9 +23,10 @@ namespace mlir { class FuncOp; -class FunctionPassBase; struct LogicalResult; class MLIRContext; +template class OpPassBase; +using FunctionPassBase = OpPassBase; class RewritePattern; // Owning list of rewriting patterns. diff --git a/third_party/mlir/include/mlir/Conversion/GPUToCUDA/GPUToCUDAPass.h b/third_party/mlir/include/mlir/Conversion/GPUToCUDA/GPUToCUDAPass.h index b8b7a1e37ef..8d5c5013599 100644 --- a/third_party/mlir/include/mlir/Conversion/GPUToCUDA/GPUToCUDAPass.h +++ b/third_party/mlir/include/mlir/Conversion/GPUToCUDA/GPUToCUDAPass.h @@ -27,7 +27,7 @@ namespace mlir { class FuncOp; class Location; -class ModulePassBase; +class ModuleOp; class OpBuilder; class Value; @@ -35,6 +35,9 @@ namespace LLVM { class LLVMDialect; } +template class OpPassBase; +using ModulePassBase = OpPassBase; + using OwnedCubin = std::unique_ptr>; using CubinGenerator = std::function; diff --git a/third_party/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h b/third_party/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h index 35f231464f1..01e50baa592 100644 --- a/third_party/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h +++ b/third_party/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h @@ -21,9 +21,12 @@ namespace mlir { class LLVMTypeConverter; -class ModulePassBase; class OwningRewritePatternList; +class ModuleOp; +template class OpPassBase; +using ModulePassBase = OpPassBase; + /// Collect a set of patterns to convert from the GPU dialect to NVVM. void populateGpuToNVVMConversionPatterns(LLVMTypeConverter &converter, OwningRewritePatternList &patterns); diff --git a/third_party/mlir/include/mlir/Conversion/LoopsToGPU/LoopsToGPUPass.h b/third_party/mlir/include/mlir/Conversion/LoopsToGPU/LoopsToGPUPass.h index 3d32c36c43c..9ef21ea97b6 100644 --- a/third_party/mlir/include/mlir/Conversion/LoopsToGPU/LoopsToGPUPass.h +++ b/third_party/mlir/include/mlir/Conversion/LoopsToGPU/LoopsToGPUPass.h @@ -20,7 +20,9 @@ #include namespace mlir { -class FunctionPassBase; +class FuncOp; +template class OpPassBase; +using FunctionPassBase = OpPassBase; /// Create a pass that converts loop nests into GPU kernels. It considers /// top-level affine.for and linalg.for operations as roots of loop nests and diff --git a/third_party/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h b/third_party/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h index d2f416b35fe..10aa8ff9628 100644 --- a/third_party/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h +++ b/third_party/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h @@ -33,7 +33,8 @@ class LLVMTypeConverter; struct LogicalResult; class MLIRContext; class ModuleOp; -class ModulePassBase; +template class OpPassBase; +using ModulePassBase = OpPassBase; class RewritePattern; class Type; diff --git a/third_party/mlir/include/mlir/Conversion/VectorToLLVM/VectorToLLVM.h b/third_party/mlir/include/mlir/Conversion/VectorToLLVM/VectorToLLVM.h index 7334c67e0d3..c781858a672 100644 --- a/third_party/mlir/include/mlir/Conversion/VectorToLLVM/VectorToLLVM.h +++ b/third_party/mlir/include/mlir/Conversion/VectorToLLVM/VectorToLLVM.h @@ -19,9 +19,12 @@ namespace mlir { class LLVMTypeConverter; -class ModulePassBase; +class ModuleOp; class OwningRewritePatternList; +template class OpPassBase; +using ModulePassBase = OpPassBase; + /// Collect a set of patterns to convert from the Vector dialect to LLVM. void populateVectorToLLVMConversionPatterns(LLVMTypeConverter &converter, OwningRewritePatternList &patterns); diff --git a/third_party/mlir/include/mlir/Dialect/FxpMathOps/Passes.h b/third_party/mlir/include/mlir/Dialect/FxpMathOps/Passes.h index 74c634a6889..f4099ab7754 100644 --- a/third_party/mlir/include/mlir/Dialect/FxpMathOps/Passes.h +++ b/third_party/mlir/include/mlir/Dialect/FxpMathOps/Passes.h @@ -23,7 +23,9 @@ #define MLIR_DIALECT_FXPMATHOPS_PASSES_H namespace mlir { -class FunctionPassBase; +class FuncOp; +template class OpPassBase; +using FunctionPassBase = OpPassBase; namespace fxpmath { diff --git a/third_party/mlir/include/mlir/Dialect/GPU/Passes.h b/third_party/mlir/include/mlir/Dialect/GPU/Passes.h index d562b5835c7..14a9f013c99 100644 --- a/third_party/mlir/include/mlir/Dialect/GPU/Passes.h +++ b/third_party/mlir/include/mlir/Dialect/GPU/Passes.h @@ -26,7 +26,9 @@ namespace mlir { -class ModulePassBase; +class ModuleOp; +template class OpPassBase; +using ModulePassBase = OpPassBase; std::unique_ptr createGpuKernelOutliningPass(); diff --git a/third_party/mlir/include/mlir/Dialect/Linalg/Passes.h b/third_party/mlir/include/mlir/Dialect/Linalg/Passes.h index e17439f6eea..118e278ef60 100644 --- a/third_party/mlir/include/mlir/Dialect/Linalg/Passes.h +++ b/third_party/mlir/include/mlir/Dialect/Linalg/Passes.h @@ -26,8 +26,11 @@ #include "llvm/ADT/ArrayRef.h" namespace mlir { -class FunctionPassBase; -class ModulePassBase; +class FuncOp; +class ModuleOp; +template class OpPassBase; +using FunctionPassBase = OpPassBase; +using ModulePassBase = OpPassBase; namespace linalg { std::unique_ptr diff --git a/third_party/mlir/include/mlir/Dialect/QuantOps/Passes.h b/third_party/mlir/include/mlir/Dialect/QuantOps/Passes.h index 1d43f7087db..5e5fd700f92 100644 --- a/third_party/mlir/include/mlir/Dialect/QuantOps/Passes.h +++ b/third_party/mlir/include/mlir/Dialect/QuantOps/Passes.h @@ -28,7 +28,9 @@ #include namespace mlir { -class FunctionPassBase; +class FuncOp; +template class OpPassBase; +using FunctionPassBase = OpPassBase; namespace quant { diff --git a/third_party/mlir/include/mlir/Pass/Pass.h b/third_party/mlir/include/mlir/Pass/Pass.h index 360eaaff9b3..79ede0e92e0 100644 --- a/third_party/mlir/include/mlir/Pass/Pass.h +++ b/third_party/mlir/include/mlir/Pass/Pass.h @@ -25,16 +25,36 @@ #include "llvm/ADT/PointerIntPair.h" namespace mlir { +namespace detail { +class FunctionPassExecutor; +class ModulePassExecutor; + +/// The state for a single execution of a pass. This provides a unified +/// interface for accessing and initializing necessary state for pass execution. +struct PassExecutionState { + PassExecutionState(Operation *ir, AnalysisManager analysisManager) + : irAndPassFailed(ir, false), analysisManager(analysisManager) {} + + /// The current operation being transformed and a bool for if the pass + /// signaled a failure. + llvm::PointerIntPair irAndPassFailed; + + /// The analysis manager for the operation. + AnalysisManager analysisManager; + + /// The set of preserved analyses for the current execution. + detail::PreservedAnalyses preservedAnalyses; +}; +} // namespace detail + /// The abstract base pass class. This class contains information describing the /// derived pass object, e.g its kind and abstract PassInfo. class Pass { public: - enum class Kind { FunctionPass, ModulePass }; - virtual ~Pass() = default; /// Returns the unique identifier that corresponds to this pass. - const PassID *getPassID() const { return passIDAndKind.getPointer(); } + const PassID *getPassID() const { return passID; } /// Returns the pass info for the specified pass class or null if unknown. static const PassInfo *lookupPassInfo(const PassID *passID); @@ -45,137 +65,61 @@ public: /// Returns the pass info for this pass. const PassInfo *lookupPassInfo() const { return lookupPassInfo(getPassID()); } - /// Return the kind of this pass. - Kind getKind() const { return passIDAndKind.getInt(); } - /// Returns the derived pass name. virtual StringRef getName() = 0; + /// Returns the name of the operation that this pass operates on. + StringRef getOpName() const { return opName; } + protected: - Pass(const PassID *passID, Kind kind) : passIDAndKind(passID, kind) {} + Pass(const PassID *passID, StringRef opName) + : passID(passID), opName(opName) {} + + /// Returns the current pass state. + detail::PassExecutionState &getPassState() { + assert(passState && "pass state was never initialized"); + return *passState; + } + + /// Return the MLIR context for the current function being transformed. + MLIRContext &getContext() { return *getOperation()->getContext(); } + + /// The polymorphic API that runs the pass over the currently held operation. + virtual void runOnOperation() = 0; + + /// A clone method to create a copy of this pass. + virtual std::unique_ptr clone() const = 0; + + /// Return the current operation being transformed. + Operation *getOperation() { + return getPassState().irAndPassFailed.getPointer(); + } + + /// Returns the current analysis manager. + AnalysisManager getAnalysisManager() { + return getPassState().analysisManager; + } private: + /// Forwarding function to execute this pass on the given operation. + LLVM_NODISCARD + LogicalResult run(Operation *op, AnalysisManager am); + /// Out of line virtual method to ensure vtables and metadata are emitted to a /// single .o file. virtual void anchor(); - /// Represents a unique identifier for the pass and its kind. - llvm::PointerIntPair passIDAndKind; -}; + /// Represents a unique identifier for the pass. + const PassID *passID; -namespace detail { -class FunctionPassExecutor; -class ModulePassExecutor; - -/// The state for a single execution of a pass. This provides a unified -/// interface for accessing and initializing necessary state for pass execution. -template struct PassExecutionState { - PassExecutionState(IRUnitT ir, AnalysisManager analysisManager) - : irAndPassFailed(ir, false), analysisManager(analysisManager) {} - - /// The current IR unit being transformed and a bool for if the pass signaled - /// a failure. - llvm::PointerIntPair irAndPassFailed; - - /// The analysis manager for the IR unit. - AnalysisManager analysisManager; - - /// The set of preserved analyses for the current execution. - detail::PreservedAnalyses preservedAnalyses; -}; -} // namespace detail - -/// Pass to transform a specific function within a module. Derived passes should -/// not inherit from this class directly, and instead should use the CRTP -/// FunctionPass class. -class FunctionPassBase : public Pass { - using PassStateT = detail::PassExecutionState; - -public: - static bool classof(const Pass *pass) { - return pass->getKind() == Kind::FunctionPass; - } - -protected: - explicit FunctionPassBase(const PassID *id) : Pass(id, Kind::FunctionPass) {} - - /// The polymorphic API that runs the pass over the currently held function. - virtual void runOnFunction() = 0; - - /// A clone method to create a copy of this pass. - virtual std::unique_ptr clone() const = 0; - - /// Return the current function being transformed. - FuncOp getFunction() { return getPassState().irAndPassFailed.getPointer(); } - - /// Return the MLIR context for the current function being transformed. - MLIRContext &getContext() { return *getFunction().getContext(); } - - /// Returns the current pass state. - PassStateT &getPassState() { - assert(passState && "pass state was never initialized"); - return *passState; - } - - /// Returns the current analysis manager. - AnalysisManager getAnalysisManager() { - return getPassState().analysisManager; - } - -private: - /// Forwarding function to execute this pass. - LLVM_NODISCARD - LogicalResult run(FuncOp fn, AnalysisManager am); + /// The name of the operation that this pass operates on. + StringRef opName; /// The current execution state for the pass. - llvm::Optional passState; + llvm::Optional passState; - /// Allow access to 'run'. + /// Allow access to 'clone' and 'run'. friend detail::FunctionPassExecutor; -}; - -/// Pass to transform a module. Derived passes should not inherit from this -/// class directly, and instead should use the CRTP ModulePass class. -class ModulePassBase : public Pass { - using PassStateT = detail::PassExecutionState; - -public: - static bool classof(const Pass *pass) { - return pass->getKind() == Kind::ModulePass; - } - -protected: - explicit ModulePassBase(const PassID *id) : Pass(id, Kind::ModulePass) {} - - /// The polymorphic API that runs the pass over the currently held module. - virtual void runOnModule() = 0; - - /// Return the current module being transformed. - ModuleOp getModule() { return getPassState().irAndPassFailed.getPointer(); } - - /// Return the MLIR context for the current module being transformed. - MLIRContext &getContext() { return *getModule().getContext(); } - - /// Returns the current pass state. - PassStateT &getPassState() { - assert(passState && "pass state was never initialized"); - return *passState; - } - - /// Returns the current analysis manager. - AnalysisManager getAnalysisManager() { - return getPassState().analysisManager; - } - -private: - /// Forwarding function to execute this pass. - LLVM_NODISCARD - LogicalResult run(ModuleOp module, AnalysisManager am); - - /// The current execution state for the pass. - llvm::Optional passState; - - /// Allow access to 'run'. friend detail::ModulePassExecutor; }; @@ -185,7 +129,7 @@ private: namespace detail { /// The opaque CRTP model of a pass. This class provides utilities for derived /// pass execution and handles all of the necessary polymorphic API. -template +template class PassModel : public BasePassT { public: /// Support isa/dyn_cast functionality for the derived pass class. @@ -194,7 +138,7 @@ public: } protected: - PassModel() : BasePassT(PassID::getID()) {} + PassModel(StringRef opName) : BasePassT(PassID::getID(), opName) {} /// Signal that some invariant was broken when running. The IR is allowed to /// be in an invalid state. @@ -234,9 +178,75 @@ protected: name.consume_front("(anonymous namespace)::"); return name; } + + /// A clone method to create a copy of this pass. + std::unique_ptr clone() const override { + return std::make_unique(*static_cast(this)); + } + + /// Returns the analysis for the parent operation if it exists. + template + llvm::Optional> + getCachedParentAnalysis(Operation *parent) { + return this->getAnalysisManager() + .template getCachedParentAnalysis(parent); + } + template + llvm::Optional> getCachedParentAnalysis() { + return this->getAnalysisManager() + .template getCachedParentAnalysis( + this->getOperation()->getParentOp()); + } + + /// Returns the analysis for the given child operation if it exists. + template + llvm::Optional> + getCachedChildAnalysis(Operation *child) { + return this->getAnalysisManager() + .template getCachedChildAnalysis(child); + } + + /// Returns the analysis for the given child operation, or creates it if it + /// doesn't exist. + template AnalysisT &getChildAnalysis(Operation *child) { + return this->getAnalysisManager().template getChildAnalysis( + child); + } }; } // end namespace detail +/// Utility base class for OpPass below to denote an opaque pass operating on a +/// specific operation type. +template class OpPassBase : public Pass { +public: + using Pass::Pass; + + /// Support isa/dyn_cast functionality. + static bool classof(const Pass *pass) { + return pass->getOpName() == OpT::getOperationName(); + } +}; + +/// Pass to transform an operation of a specific type. +/// +/// Operation passes must not: +/// - read or modify any other operations within the parent region, as +/// other threads may be manipulating them concurrently. +/// - modify any state within the parent operation, this includes adding +/// additional operations. +/// +/// Derived function passes are expected to provide the following: +/// - A 'void runOnOperation()' method. +template +class OpPass : public detail::PassModel> { +protected: + OpPass() + : detail::PassModel>(OpT::getOperationName()) {} + + /// Return the current operation being transformed. + OpT getOperation() { return cast(Pass::getOperation()); } +}; + /// A model for providing function pass specific utilities. /// /// Function passes must not: @@ -247,41 +257,39 @@ protected: /// /// Derived function passes are expected to provide the following: /// - A 'void runOnFunction()' method. -template -struct FunctionPass : public detail::PassModel { - /// Returns the analysis for the parent module if it exists. - template - llvm::Optional> getCachedModuleAnalysis() { - return this->getAnalysisManager() - .template getCachedParentAnalysis( - this->getFunction().getParentOp()); +template struct FunctionPass : public OpPass { + /// The polymorphic API that runs the pass over the currently held function. + virtual void runOnFunction() = 0; + + /// The polymorphic API that runs the pass over the currently held operation. + void runOnOperation() final { + if (!getFunction().isExternal()) + runOnFunction(); } - /// A clone method to create a copy of this pass. - std::unique_ptr clone() const override { - return std::make_unique(*static_cast(this)); - } + /// Return the current module being transformed. + FuncOp getFunction() { return this->getOperation(); } }; /// A model for providing module pass specific utilities. /// /// Derived module passes are expected to provide the following: /// - A 'void runOnModule()' method. -template -struct ModulePass : public detail::PassModel { - /// Returns the analysis for a child function. - template AnalysisT &getFunctionAnalysis(FuncOp f) { - return this->getAnalysisManager().template getChildAnalysis(f); - } +template struct ModulePass : public OpPass { + /// The polymorphic API that runs the pass over the currently held module. + virtual void runOnModule() = 0; - /// Returns an existing analysis for a child function if it exists. - template - llvm::Optional> - getCachedFunctionAnalysis(FuncOp f) { - return this->getAnalysisManager() - .template getCachedChildAnalysis(f); - } + /// The polymorphic API that runs the pass over the currently held operation. + void runOnOperation() final { runOnModule(); } + + /// Return the current module being transformed. + ModuleOp getModule() { return this->getOperation(); } }; + +/// Using directives defining legacy base classes. +// TODO(riverriddle) These should be removed in favor of OpPassBase. +using FunctionPassBase = OpPassBase; +using ModulePassBase = OpPassBase; } // end namespace mlir #endif // MLIR_PASS_PASS_H diff --git a/third_party/mlir/include/mlir/Pass/PassManager.h b/third_party/mlir/include/mlir/Pass/PassManager.h index b01445eae4c..888d903a294 100644 --- a/third_party/mlir/include/mlir/Pass/PassManager.h +++ b/third_party/mlir/include/mlir/Pass/PassManager.h @@ -26,9 +26,12 @@ class Any; } // end namespace llvm namespace mlir { -class FunctionPassBase; +class FuncOp; class ModuleOp; -class ModulePassBase; +template class OpPassBase; +using FunctionPassBase = OpPassBase; +using ModulePassBase = OpPassBase; + class Pass; class PassInstrumentation; class PassInstrumentor; diff --git a/third_party/mlir/include/mlir/Transforms/Passes.h b/third_party/mlir/include/mlir/Transforms/Passes.h index 693c7b0ae00..dc3a213f09a 100644 --- a/third_party/mlir/include/mlir/Transforms/Passes.h +++ b/third_party/mlir/include/mlir/Transforms/Passes.h @@ -30,8 +30,11 @@ namespace mlir { class AffineForOp; -class FunctionPassBase; -class ModulePassBase; +class FuncOp; +class ModuleOp; +template class OpPassBase; +using FunctionPassBase = OpPassBase; +using ModulePassBase = OpPassBase; /// Creates a constant folding pass. Note that this pass solely provides simple /// top-down constant folding functionality; it is intended to be used for diff --git a/third_party/mlir/include/mlir/Transforms/ViewRegionGraph.h b/third_party/mlir/include/mlir/Transforms/ViewRegionGraph.h index 61da9f11f19..f54d35643eb 100644 --- a/third_party/mlir/include/mlir/Transforms/ViewRegionGraph.h +++ b/third_party/mlir/include/mlir/Transforms/ViewRegionGraph.h @@ -27,7 +27,9 @@ #include "llvm/Support/raw_ostream.h" namespace mlir { -class FunctionPassBase; +class FuncOp; +template class OpPassBase; +using FunctionPassBase = OpPassBase; class Region; /// Displays the CFG in a window. This is for use from the debugger and diff --git a/third_party/mlir/lib/Pass/Pass.cpp b/third_party/mlir/lib/Pass/Pass.cpp index 0892aa087e8..e208e2029d9 100644 --- a/third_party/mlir/lib/Pass/Pass.cpp +++ b/third_party/mlir/lib/Pass/Pass.cpp @@ -42,17 +42,16 @@ using namespace mlir::detail; void Pass::anchor() {} /// Forwarding function to execute this pass. -LogicalResult FunctionPassBase::run(FuncOp fn, AnalysisManager am) { - // Initialize the pass state. - passState.emplace(fn, am); +LogicalResult Pass::run(Operation *op, AnalysisManager am) { + passState.emplace(op, am); // Instrument before the pass has run. auto pi = am.getPassInstrumentor(); if (pi) - pi->runBeforePass(this, fn); + pi->runBeforePass(this, op); - // Invoke the virtual runOnFunction function. - runOnFunction(); + // Invoke the virtual runOnOperation method. + runOnOperation(); // Invalidate any non preserved analyses. am.invalidate(passState->preservedAnalyses); @@ -61,38 +60,9 @@ LogicalResult FunctionPassBase::run(FuncOp fn, AnalysisManager am) { bool passFailed = passState->irAndPassFailed.getInt(); if (pi) { if (passFailed) - pi->runAfterPassFailed(this, fn); + pi->runAfterPassFailed(this, op); else - pi->runAfterPass(this, fn); - } - - // Return if the pass signaled a failure. - return failure(passFailed); -} - -/// Forwarding function to execute this pass. -LogicalResult ModulePassBase::run(ModuleOp module, AnalysisManager am) { - // Initialize the pass state. - passState.emplace(module, am); - - // Instrument before the pass has run. - auto pi = am.getPassInstrumentor(); - if (pi) - pi->runBeforePass(this, module); - - // Invoke the virtual runOnModule function. - runOnModule(); - - // Invalidate any non preserved analyses. - am.invalidate(passState->preservedAnalyses); - - // Instrument after the pass has run. - bool passFailed = passState->irAndPassFailed.getInt(); - if (pi) { - if (passFailed) - pi->runAfterPassFailed(this, module); - else - pi->runAfterPass(this, module); + pi->runAfterPass(this, op); } // Return if the pass signaled a failure. @@ -106,7 +76,7 @@ LogicalResult ModulePassBase::run(ModuleOp module, AnalysisManager am) { FunctionPassExecutor::FunctionPassExecutor(const FunctionPassExecutor &rhs) : PassExecutor(Kind::FunctionExecutor) { for (auto &pass : rhs.passes) - addPass(pass->clone()); + addPass(cast(pass->clone())); } /// Run all of the passes in this manager over the current function. @@ -265,14 +235,9 @@ void PassManager::disableMultithreading(bool disable) { /// Add an opaque pass pointer to the current manager. This takes ownership /// over the provided pass pointer. void PassManager::addPass(std::unique_ptr pass) { - switch (pass->getKind()) { - case Pass::Kind::FunctionPass: - addPass(cast(std::move(pass))); - break; - case Pass::Kind::ModulePass: - addPass(cast(std::move(pass))); - break; - } + if (isa(pass.get())) + return addPass(cast(std::move(pass))); + addPass(cast(std::move(pass))); } /// Add a module pass to the current manager. This takes ownership over the diff --git a/third_party/mlir/lib/Pass/PassDetail.h b/third_party/mlir/lib/Pass/PassDetail.h index aa60cfb23ea..40e75f4cf3a 100644 --- a/third_party/mlir/lib/Pass/PassDetail.h +++ b/third_party/mlir/lib/Pass/PassDetail.h @@ -66,7 +66,7 @@ public: /// Add a pass to the current executor. This takes ownership over the provided /// pass pointer. - void addPass(std::unique_ptr pass) { + void addPass(std::unique_ptr pass) { passes.push_back(std::move(pass)); } @@ -78,7 +78,7 @@ public: } private: - std::vector> passes; + std::vector> passes; }; /// A pass executor that contains a list of passes over a module unit.