NFC: Implement OwningRewritePatternList as a class instead of a using directive.

This allows for proper forward declaration, as opposed to leaking the internal implementation via a using directive. This also allows for all pattern building to go through 'insert' methods on the OwningRewritePatternList, replacing uses of 'push_back' and 'RewriteListBuilder'.

PiperOrigin-RevId: 261816316
This commit is contained in:
River Riddle 2019-08-05 18:37:56 -07:00 committed by TensorFlower Gardener
parent 85a9058eff
commit 4add221b50
33 changed files with 129 additions and 153 deletions

View File

@ -506,7 +506,7 @@ OpFoldResult ReshapeOp::fold(ArrayRef<Attribute> operands) {
void ReshapeOp::getCanonicalizationPatterns(OwningRewritePatternList &results, void ReshapeOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) { MLIRContext *context) {
results.push_back(llvm::make_unique<RemoveAdjacentReshape>(context)); results.insert<RemoveAdjacentReshape>(context);
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -591,7 +591,7 @@ struct DropFakeQuant : public RewritePattern {
void FakeQuantOp::getCanonicalizationPatterns(OwningRewritePatternList &results, void FakeQuantOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) { MLIRContext *context) {
results.push_back(llvm::make_unique<DropFakeQuant>(context)); results.insert<DropFakeQuant>(context);
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//

View File

@ -205,9 +205,9 @@ void LegalizeTF::runOnFunction() {
// Add the generated patterns to the list. // Add the generated patterns to the list.
populateWithGenerated(ctx, &patterns); populateWithGenerated(ctx, &patterns);
RewriteListBuilder<ConvertTFConcatOp, ConvertTFConcatV2Op, ConvertTFMatMulOp, patterns.insert<ConvertTFConcatOp, ConvertTFConcatV2Op, ConvertTFMatMulOp,
ConvertTFPackOp, ConvertTFSplitOp, ConvertTFSplitVOp, ConvertTFPackOp, ConvertTFSplitOp, ConvertTFSplitVOp,
ConvertTFUnpackOp>::build(patterns, ctx); ConvertTFUnpackOp>(ctx);
applyPatternsGreedily(func, std::move(patterns)); applyPatternsGreedily(func, std::move(patterns));
} }

View File

@ -240,8 +240,8 @@ void Optimize::runOnFunction() {
auto func = getFunction(); auto func = getFunction();
// Add the generated patterns to the list. // Add the generated patterns to the list.
TFL::populateWithGenerated(ctx, &patterns); TFL::populateWithGenerated(ctx, &patterns);
RewriteListBuilder<FuseFullyConnectedAndAdd, FuseFullyConnectedAndRelu, patterns.insert<FuseFullyConnectedAndAdd, FuseFullyConnectedAndRelu,
PadStridedSliceDims>::build(patterns, ctx); PadStridedSliceDims>(ctx);
applyPatternsGreedily(func, std::move(patterns)); applyPatternsGreedily(func, std::move(patterns));
} }

View File

@ -381,8 +381,7 @@ void PrepareTFPass::runOnFunction() {
// parameters from the TF Quant ops, thus this pattern should run with the // parameters from the TF Quant ops, thus this pattern should run with the
// first `applyPatternsGreedily` method, which would otherwise removes the // first `applyPatternsGreedily` method, which would otherwise removes the
// TF FakeQuant ops by the constant folding. // TF FakeQuant ops by the constant folding.
patterns.push_back( patterns.insert<InsertTFLQuantOpsAfterTFFakeQuantOp>(&getContext());
llvm::make_unique<InsertTFLQuantOpsAfterTFFakeQuantOp>(&getContext()));
TFL::populateWithGenerated(&getContext(), &patterns); TFL::populateWithGenerated(&getContext(), &patterns);
// TODO(karimnosseir): Split to separate pass probably after // TODO(karimnosseir): Split to separate pass probably after
// deciding on long term plan for this optimization. // deciding on long term plan for this optimization.
@ -394,9 +393,8 @@ void PrepareTFPass::runOnFunction() {
// Load the generated pattern again, so new quantization pass-through // Load the generated pattern again, so new quantization pass-through
// will be applied. // will be applied.
TFL::populateWithGenerated(&getContext(), &patterns); TFL::populateWithGenerated(&getContext(), &patterns);
patterns.push_back(llvm::make_unique<ConvertTFConv2D>(&getContext())); patterns.insert<ConvertTFConv2D, ConvertTFDepthwiseConv2dNative>(
patterns.push_back( &getContext());
llvm::make_unique<ConvertTFDepthwiseConv2dNative>(&getContext()));
applyPatternsGreedily(func, std::move(patterns)); applyPatternsGreedily(func, std::move(patterns));
} }

View File

@ -55,8 +55,8 @@ void QuantizePass::runOnFunction() {
auto func = getFunction(); auto func = getFunction();
auto* ctx = func.getContext(); auto* ctx = func.getContext();
TFL::populateWithGenerated(ctx, &patterns); TFL::populateWithGenerated(ctx, &patterns);
mlir::RewriteListBuilder<mlir::TFL::GenericFullQuantizationPattern< patterns.insert<mlir::TFL::GenericFullQuantizationPattern<
mlir::TFL::QuantizeOp, mlir::TFL::DequantizeOp>>::build(patterns, ctx); mlir::TFL::QuantizeOp, mlir::TFL::DequantizeOp>>(ctx);
applyPatternsGreedily(func, std::move(patterns)); applyPatternsGreedily(func, std::move(patterns));
} }
} // namespace } // namespace

View File

@ -106,7 +106,7 @@ namespace {
void AddOp::getCanonicalizationPatterns(OwningRewritePatternList &results, void AddOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) { MLIRContext *context) {
RewriteListBuilder<AddToAddV2>::build(results, context); results.insert<AddToAddV2>(context);
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -115,7 +115,7 @@ void AddOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
void AddV2Op::getCanonicalizationPatterns(OwningRewritePatternList &results, void AddV2Op::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) { MLIRContext *context) {
RewriteListBuilder<AddV2OfNegLeft, AddV2OfNegRight>::build(results, context); results.insert<AddV2OfNegLeft, AddV2OfNegRight>(context);
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -144,7 +144,7 @@ struct AssertWithTrue : public OpRewritePattern<AssertOp> {
void AssertOp::getCanonicalizationPatterns(OwningRewritePatternList &results, void AssertOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) { MLIRContext *context) {
RewriteListBuilder<AssertWithTrue>::build(results, context); results.insert<AssertWithTrue>(context);
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -153,7 +153,7 @@ void AssertOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
void BitcastOp::getCanonicalizationPatterns(OwningRewritePatternList &results, void BitcastOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) { MLIRContext *context) {
RewriteListBuilder<BitcastSameType, BitcastNested>::build(results, context); results.insert<BitcastSameType, BitcastNested>(context);
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -174,7 +174,7 @@ static LogicalResult Verify(BroadcastToOp op) {
void CastOp::getCanonicalizationPatterns(OwningRewritePatternList &results, void CastOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) { MLIRContext *context) {
RewriteListBuilder<CastSameType>::build(results, context); results.insert<CastSameType>(context);
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -183,7 +183,7 @@ void CastOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
void ConjOp::getCanonicalizationPatterns(OwningRewritePatternList &results, void ConjOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) { MLIRContext *context) {
RewriteListBuilder<ConjNested>::build(results, context); results.insert<ConjNested>(context);
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -239,7 +239,7 @@ void ConstOp::build(Builder *builder, OperationState *result, Type type,
void DivOp::getCanonicalizationPatterns(OwningRewritePatternList &results, void DivOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) { MLIRContext *context) {
RewriteListBuilder<DivWithSqrtDivisor>::build(results, context); results.insert<DivWithSqrtDivisor>(context);
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -415,7 +415,7 @@ static LogicalResult Verify(IfOp op) {
void InvertOp::getCanonicalizationPatterns(OwningRewritePatternList &results, void InvertOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) { MLIRContext *context) {
RewriteListBuilder<InvertNested>::build(results, context); results.insert<InvertNested>(context);
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -449,7 +449,7 @@ OpFoldResult LeakyReluOp::fold(ArrayRef<Attribute> operands) {
void LogOp::getCanonicalizationPatterns(OwningRewritePatternList &results, void LogOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) { MLIRContext *context) {
RewriteListBuilder<LogOfSoftmax>::build(results, context); results.insert<LogOfSoftmax>(context);
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -458,10 +458,9 @@ void LogOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
void LogicalNotOp::getCanonicalizationPatterns( void LogicalNotOp::getCanonicalizationPatterns(
OwningRewritePatternList &results, MLIRContext *context) { OwningRewritePatternList &results, MLIRContext *context) {
RewriteListBuilder<LogicalNotNested, LogicalNotOfEqual, LogicalNotOfNotEqual, results.insert<LogicalNotNested, LogicalNotOfEqual, LogicalNotOfNotEqual,
LogicalNotOfGreater, LogicalNotOfGreaterEqual, LogicalNotOfGreater, LogicalNotOfGreaterEqual,
LogicalNotOfLess, LogicalNotOfLessEqual>::build(results, LogicalNotOfLess, LogicalNotOfLessEqual>(context);
context);
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -470,7 +469,7 @@ void LogicalNotOp::getCanonicalizationPatterns(
void NegOp::getCanonicalizationPatterns(OwningRewritePatternList &results, void NegOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) { MLIRContext *context) {
RewriteListBuilder<NegNested>::build(results, context); results.insert<NegNested>(context);
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -479,7 +478,7 @@ void NegOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
void ReciprocalOp::getCanonicalizationPatterns( void ReciprocalOp::getCanonicalizationPatterns(
OwningRewritePatternList &results, MLIRContext *context) { OwningRewritePatternList &results, MLIRContext *context) {
RewriteListBuilder<ReciprocalNested>::build(results, context); results.insert<ReciprocalNested>(context);
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -538,7 +537,7 @@ void RankOp::build(Builder *builder, OperationState *result, Value *input) {
void RealDivOp::getCanonicalizationPatterns(OwningRewritePatternList &results, void RealDivOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) { MLIRContext *context) {
RewriteListBuilder<RealDivWithSqrtDivisor>::build(results, context); results.insert<RealDivWithSqrtDivisor>(context);
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -698,7 +697,7 @@ static LogicalResult Verify(SoftmaxOp op) {
void SquareOp::getCanonicalizationPatterns(OwningRewritePatternList &results, void SquareOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) { MLIRContext *context) {
RewriteListBuilder<SquareOfSub>::build(results, context); results.insert<SquareOfSub>(context);
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -707,7 +706,7 @@ void SquareOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
void SubOp::getCanonicalizationPatterns(OwningRewritePatternList &results, void SubOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) { MLIRContext *context) {
RewriteListBuilder<SubOfNeg>::build(results, context); results.insert<SubOfNeg>(context);
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -778,7 +777,7 @@ void TransposeOp::build(Builder *builder, OperationState *result, Value *x,
void TruncateDivOp::getCanonicalizationPatterns( void TruncateDivOp::getCanonicalizationPatterns(
OwningRewritePatternList &results, MLIRContext *context) { OwningRewritePatternList &results, MLIRContext *context) {
RewriteListBuilder<TruncateDivWithSqrtDivisor>::build(results, context); results.insert<TruncateDivWithSqrtDivisor>(context);
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -867,7 +866,7 @@ static LogicalResult Verify(WhileOp op) {
void XdivyOp::getCanonicalizationPatterns(OwningRewritePatternList &results, void XdivyOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) { MLIRContext *context) {
RewriteListBuilder<XdivyWithSqrtDivisor>::build(results, context); results.insert<XdivyWithSqrtDivisor>(context);
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//

View File

@ -133,10 +133,8 @@ void LegalizeToStandard::runOnFunction() {
auto func = getFunction(); auto func = getFunction();
mlir::XLA::populateWithGenerated(func.getContext(), &patterns); mlir::XLA::populateWithGenerated(func.getContext(), &patterns);
patterns.push_back( patterns.insert<mlir::XLA::CompareFConvert, mlir::XLA::CompareIConvert>(
llvm::make_unique<mlir::XLA::CompareFConvert>(&getContext())); &getContext());
patterns.push_back(
llvm::make_unique<mlir::XLA::CompareIConvert>(&getContext()));
applyPatternsGreedily(func, std::move(patterns)); applyPatternsGreedily(func, std::move(patterns));
} }

View File

@ -29,7 +29,7 @@ class MLIRContext;
class RewritePattern; class RewritePattern;
// Owning list of rewriting patterns. // Owning list of rewriting patterns.
using OwningRewritePatternList = std::vector<std::unique_ptr<RewritePattern>>; class OwningRewritePatternList;
/// Collect a set of patterns to lower from loop.for, loop.if, and /// Collect a set of patterns to lower from loop.for, loop.if, and
/// loop.terminator to CFG operations within the Standard dialect, in particular /// loop.terminator to CFG operations within the Standard dialect, in particular

View File

@ -38,7 +38,7 @@ class RewritePattern;
class Type; class Type;
// Owning list of rewriting patterns. // Owning list of rewriting patterns.
using OwningRewritePatternList = std::vector<std::unique_ptr<RewritePattern>>; class OwningRewritePatternList;
/// Type for a callback constructing the owning list of patterns for the /// Type for a callback constructing the owning list of patterns for the
/// conversion to the LLVMIR dialect. The callback is expected to append /// conversion to the LLVMIR dialect. The callback is expected to append

View File

@ -57,9 +57,7 @@ class Value;
/// either OpTy or OperandAdaptor<OpTy> seamlessly. /// either OpTy or OperandAdaptor<OpTy> seamlessly.
template <typename OpTy> using OperandAdaptor = typename OpTy::OperandAdaptor; template <typename OpTy> using OperandAdaptor = typename OpTy::OperandAdaptor;
/// This is a vector that owns the patterns inside of it. class OwningRewritePatternList;
using OwningPatternList = std::vector<std::unique_ptr<Pattern>>;
using OwningRewritePatternList = std::vector<std::unique_ptr<RewritePattern>>;
enum class OperationProperty { enum class OperationProperty {
/// This bit is set for an operation if it is a commutative operation: that /// This bit is set for an operation if it is a commutative operation: that

View File

@ -394,8 +394,39 @@ private:
// Pattern-driven rewriters // Pattern-driven rewriters
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
/// This is a vector that owns the patterns inside of it. class OwningRewritePatternList {
using OwningRewritePatternList = std::vector<std::unique_ptr<RewritePattern>>; using PatternListT = std::vector<std::unique_ptr<RewritePattern>>;
public:
PatternListT::iterator begin() { return patterns.begin(); }
PatternListT::iterator end() { return patterns.end(); }
PatternListT::const_iterator begin() const { return patterns.begin(); }
PatternListT::const_iterator end() const { return patterns.end(); }
//===--------------------------------------------------------------------===//
// Pattern Insertion
//===--------------------------------------------------------------------===//
void insert(RewritePattern *pattern) { patterns.emplace_back(pattern); }
/// Add an instance of each of the pattern types 'Ts' to the pattern list with
/// the given arguments.
// Note: ConstructorArg is necessary here to separate the two variadic lists.
template <typename... Ts, typename ConstructorArg,
typename... ConstructorArgs>
void insert(ConstructorArg &&arg, ConstructorArgs &&... args) {
// The following expands a call to emplace_back for each of the pattern
// types 'Ts'. This magic is necessary due to a limitation in the places
// that a parameter pack can be expanded in c++11.
// FIXME: In c++17 this can be simplified by using 'fold expressions'.
using dummy = int[];
(void)dummy{
0, (patterns.emplace_back(llvm::make_unique<Ts>(arg, args...)), 0)...};
}
private:
PatternListT patterns;
};
/// This class manages optimization and execution of a group of rewrite /// This class manages optimization and execution of a group of rewrite
/// patterns, providing an API for finding and applying, the best match against /// patterns, providing an API for finding and applying, the best match against
@ -404,7 +435,7 @@ using OwningRewritePatternList = std::vector<std::unique_ptr<RewritePattern>>;
class RewritePatternMatcher { class RewritePatternMatcher {
public: public:
/// Create a RewritePatternMatcher with the specified set of patterns. /// Create a RewritePatternMatcher with the specified set of patterns.
explicit RewritePatternMatcher(OwningRewritePatternList &&patterns); explicit RewritePatternMatcher(OwningRewritePatternList &patterns);
/// Try to match the given operation to a pattern and rewrite it. Return /// Try to match the given operation to a pattern and rewrite it. Return
/// true if any pattern matches. /// true if any pattern matches.
@ -416,7 +447,7 @@ private:
/// The group of patterns that are matched for optimization through this /// The group of patterns that are matched for optimization through this
/// matcher. /// matcher.
OwningRewritePatternList patterns; std::vector<RewritePattern *> patterns;
}; };
/// Rewrite the regions of the specified operation, which must be isolated from /// Rewrite the regions of the specified operation, which must be isolated from
@ -427,29 +458,6 @@ private:
/// ///
bool applyPatternsGreedily(Operation *op, OwningRewritePatternList &&patterns); bool applyPatternsGreedily(Operation *op, OwningRewritePatternList &&patterns);
/// Helper class to create a list of rewrite patterns given a list of their
/// types and a list of attributes perfect-forwarded to each of the conversion
/// constructors.
template <typename Arg, typename... Args> struct RewriteListBuilder {
template <typename... ConstructorArgs>
static void build(OwningRewritePatternList &patterns,
ConstructorArgs &&... constructorArgs) {
RewriteListBuilder<Args...>::build(
patterns, std::forward<ConstructorArgs>(constructorArgs)...);
RewriteListBuilder<Arg>::build(
patterns, std::forward<ConstructorArgs>(constructorArgs)...);
}
};
// Template specialization to stop recursion.
template <typename Arg> struct RewriteListBuilder<Arg> {
template <typename... ConstructorArgs>
static void build(OwningRewritePatternList &patterns,
ConstructorArgs &&... constructorArgs) {
patterns.emplace_back(llvm::make_unique<Arg>(
std::forward<ConstructorArgs>(constructorArgs)...));
}
};
} // end namespace mlir } // end namespace mlir
#endif // MLIR_PATTERN_MATCH_H #endif // MLIR_PATTERN_MATCH_H

View File

@ -32,7 +32,7 @@ class RewritePattern;
class Value; class Value;
// Owning list of rewriting patterns. // Owning list of rewriting patterns.
using OwningRewritePatternList = std::vector<std::unique_ptr<RewritePattern>>; class OwningRewritePatternList;
/// Emit code that computes the given affine expression using standard /// Emit code that computes the given affine expression using standard
/// arithmetic operations applied to the provided dimension and symbol values. /// arithmetic operations applied to the provided dimension and symbol values.

View File

@ -708,7 +708,7 @@ struct SimplifyAffineApply : public OpRewritePattern<AffineApplyOp> {
void AffineApplyOp::getCanonicalizationPatterns( void AffineApplyOp::getCanonicalizationPatterns(
OwningRewritePatternList &results, MLIRContext *context) { OwningRewritePatternList &results, MLIRContext *context) {
results.push_back(llvm::make_unique<SimplifyAffineApply>(context)); results.insert<SimplifyAffineApply>(context);
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -912,8 +912,7 @@ LogicalResult AffineDmaStartOp::verify() {
void AffineDmaStartOp::getCanonicalizationPatterns( void AffineDmaStartOp::getCanonicalizationPatterns(
OwningRewritePatternList &results, MLIRContext *context) { OwningRewritePatternList &results, MLIRContext *context) {
/// dma_start(memrefcast) -> dma_start /// dma_start(memrefcast) -> dma_start
results.push_back( results.insert<MemRefCastFolder>(getOperationName(), context);
llvm::make_unique<MemRefCastFolder>(getOperationName(), context));
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -989,8 +988,7 @@ LogicalResult AffineDmaWaitOp::verify() {
void AffineDmaWaitOp::getCanonicalizationPatterns( void AffineDmaWaitOp::getCanonicalizationPatterns(
OwningRewritePatternList &results, MLIRContext *context) { OwningRewritePatternList &results, MLIRContext *context) {
/// dma_wait(memrefcast) -> dma_wait /// dma_wait(memrefcast) -> dma_wait
results.push_back( results.insert<MemRefCastFolder>(getOperationName(), context);
llvm::make_unique<MemRefCastFolder>(getOperationName(), context));
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -1333,7 +1331,7 @@ struct AffineForLoopBoundFolder : public OpRewritePattern<AffineForOp> {
void AffineForOp::getCanonicalizationPatterns(OwningRewritePatternList &results, void AffineForOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) { MLIRContext *context) {
results.push_back(llvm::make_unique<AffineForLoopBoundFolder>(context)); results.insert<AffineForLoopBoundFolder>(context);
} }
AffineBound AffineForOp::getLowerBound() { AffineBound AffineForOp::getLowerBound() {
@ -1659,8 +1657,7 @@ LogicalResult AffineLoadOp::verify() {
void AffineLoadOp::getCanonicalizationPatterns( void AffineLoadOp::getCanonicalizationPatterns(
OwningRewritePatternList &results, MLIRContext *context) { OwningRewritePatternList &results, MLIRContext *context) {
/// load(memrefcast) -> load /// load(memrefcast) -> load
results.push_back( results.insert<MemRefCastFolder>(getOperationName(), context);
llvm::make_unique<MemRefCastFolder>(getOperationName(), context));
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -1752,8 +1749,7 @@ LogicalResult AffineStoreOp::verify() {
void AffineStoreOp::getCanonicalizationPatterns( void AffineStoreOp::getCanonicalizationPatterns(
OwningRewritePatternList &results, MLIRContext *context) { OwningRewritePatternList &results, MLIRContext *context) {
/// load(memrefcast) -> load /// load(memrefcast) -> load
results.push_back( results.insert<MemRefCastFolder>(getOperationName(), context);
llvm::make_unique<MemRefCastFolder>(getOperationName(), context));
} }
#define GET_OP_CLASSES #define GET_OP_CLASSES

View File

@ -258,8 +258,7 @@ IfLowering::matchAndRewrite(IfOp ifOp, PatternRewriter &rewriter) const {
void mlir::populateLoopToStdConversionPatterns( void mlir::populateLoopToStdConversionPatterns(
OwningRewritePatternList &patterns, MLIRContext *ctx) { OwningRewritePatternList &patterns, MLIRContext *ctx) {
RewriteListBuilder<ForLowering, IfLowering, TerminatorLowering>::build( patterns.insert<ForLowering, IfLowering, TerminatorLowering>(ctx);
patterns, ctx);
} }
void ControlFlowToCFGPass::runOnFunction() { void ControlFlowToCFGPass::runOnFunction() {

View File

@ -104,8 +104,7 @@ void GPUToSPIRVPass::runOnModule() {
SPIRVTypeConverter typeConverter(context); SPIRVTypeConverter typeConverter(context);
SPIRVEntryFnTypeConverter entryFnConverter(context); SPIRVEntryFnTypeConverter entryFnConverter(context);
OwningRewritePatternList patterns; OwningRewritePatternList patterns;
RewriteListBuilder<KernelFnConversion>::build( patterns.insert<KernelFnConversion>(context, typeConverter, entryFnConverter);
patterns, context, typeConverter, entryFnConverter);
populateStandardToSPIRVPatterns(context, patterns); populateStandardToSPIRVPatterns(context, patterns);
ConversionTarget target(*context); ConversionTarget target(*context);

View File

@ -1023,7 +1023,7 @@ void mlir::LLVM::ensureDistinctSuccessors(ModuleOp m) {
void mlir::populateStdToLLVMConversionPatterns( void mlir::populateStdToLLVMConversionPatterns(
LLVMTypeConverter &converter, OwningRewritePatternList &patterns) { LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
// FIXME: this should be tablegen'ed // FIXME: this should be tablegen'ed
RewriteListBuilder< patterns.insert<
AddFOpLowering, AddIOpLowering, AndOpLowering, AllocOpLowering, AddFOpLowering, AddIOpLowering, AndOpLowering, AllocOpLowering,
BranchOpLowering, CallIndirectOpLowering, CallOpLowering, CmpIOpLowering, BranchOpLowering, CallIndirectOpLowering, CallOpLowering, CmpIOpLowering,
CondBranchOpLowering, ConstLLVMOpLowering, DeallocOpLowering, CondBranchOpLowering, ConstLLVMOpLowering, DeallocOpLowering,
@ -1032,8 +1032,7 @@ void mlir::populateStdToLLVMConversionPatterns(
MemRefCastOpLowering, MulFOpLowering, MulIOpLowering, OrOpLowering, MemRefCastOpLowering, MulFOpLowering, MulIOpLowering, OrOpLowering,
RemISOpLowering, RemIUOpLowering, RemFOpLowering, ReturnOpLowering, RemISOpLowering, RemIUOpLowering, RemFOpLowering, ReturnOpLowering,
SelectOpLowering, SIToFPLowering, StoreOpLowering, SubFOpLowering, SelectOpLowering, SIToFPLowering, StoreOpLowering, SubFOpLowering,
SubIOpLowering, XOrOpLowering>::build(patterns, *converter.getDialect(), SubIOpLowering, XOrOpLowering>(*converter.getDialect(), converter);
converter);
} }
// Convert types using the stored LLVM IR module. // Convert types using the stored LLVM IR module.

View File

@ -201,6 +201,6 @@ void populateStandardToSPIRVPatterns(MLIRContext *context,
OwningRewritePatternList &patterns) { OwningRewritePatternList &patterns) {
populateWithGenerated(context, &patterns); populateWithGenerated(context, &patterns);
// Add the return op conversion. // Add the return op conversion.
RewriteListBuilder<ReturnToSPIRVConversion>::build(patterns, context); patterns.insert<ReturnToSPIRVConversion>(context);
} }
} // namespace mlir } // namespace mlir

View File

@ -368,8 +368,7 @@ void LowerUniformRealMathPass::runOnFunction() {
auto fn = getFunction(); auto fn = getFunction();
OwningRewritePatternList patterns; OwningRewritePatternList patterns;
auto *context = &getContext(); auto *context = &getContext();
patterns.push_back(llvm::make_unique<UniformRealAddEwPattern>(context)); patterns.insert<UniformRealAddEwPattern, UniformRealMulEwPattern>(context);
patterns.push_back(llvm::make_unique<UniformRealMulEwPattern>(context));
applyPatternsGreedily(fn, std::move(patterns)); applyPatternsGreedily(fn, std::move(patterns));
} }
@ -389,7 +388,7 @@ void LowerUniformCastsPass::runOnFunction() {
auto fn = getFunction(); auto fn = getFunction();
OwningRewritePatternList patterns; OwningRewritePatternList patterns;
auto *context = &getContext(); auto *context = &getContext();
patterns.push_back(llvm::make_unique<UniformDequantizePattern>(context)); patterns.insert<UniformDequantizePattern>(context);
applyPatternsGreedily(fn, std::move(patterns)); applyPatternsGreedily(fn, std::move(patterns));
} }

View File

@ -372,7 +372,7 @@ class PropagateConstantBounds : public OpRewritePattern<LaunchOp> {
void LaunchOp::getCanonicalizationPatterns(OwningRewritePatternList &results, void LaunchOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) { MLIRContext *context) {
RewriteListBuilder<PropagateConstantBounds>::build(results, context); results.insert<PropagateConstantBounds>(context);
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//

View File

@ -60,8 +60,7 @@ public:
void StorageCastOp::getCanonicalizationPatterns( void StorageCastOp::getCanonicalizationPatterns(
OwningRewritePatternList &patterns, MLIRContext *context) { OwningRewritePatternList &patterns, MLIRContext *context) {
patterns.push_back( patterns.insert<RemoveRedundantStorageCastsRewrite>(context);
llvm::make_unique<RemoveRedundantStorageCastsRewrite>(context));
} }
QuantizationDialect::QuantizationDialect(MLIRContext *context) QuantizationDialect::QuantizationDialect(MLIRContext *context)

View File

@ -108,7 +108,7 @@ void ConvertConstPass::runOnFunction() {
OwningRewritePatternList patterns; OwningRewritePatternList patterns;
auto func = getFunction(); auto func = getFunction();
auto *context = &getContext(); auto *context = &getContext();
patterns.push_back(llvm::make_unique<QuantizedConstRewrite>(context)); patterns.insert<QuantizedConstRewrite>(context);
applyPatternsGreedily(func, std::move(patterns)); applyPatternsGreedily(func, std::move(patterns));
} }

View File

@ -97,8 +97,7 @@ void ConvertSimulatedQuantPass::runOnFunction() {
OwningRewritePatternList patterns; OwningRewritePatternList patterns;
auto func = getFunction(); auto func = getFunction();
auto *context = &getContext(); auto *context = &getContext();
patterns.push_back( patterns.insert<ConstFakeQuantRewrite>(context, &hadFailure);
llvm::make_unique<ConstFakeQuantRewrite>(context, &hadFailure));
applyPatternsGreedily(func, std::move(patterns)); applyPatternsGreedily(func, std::move(patterns));
if (hadFailure) if (hadFailure)
signalPassFailure(); signalPassFailure();

View File

@ -149,12 +149,13 @@ void PatternRewriter::updatedRootInPlace(
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
RewritePatternMatcher::RewritePatternMatcher( RewritePatternMatcher::RewritePatternMatcher(
OwningRewritePatternList &&patterns) OwningRewritePatternList &patterns) {
: patterns(std::move(patterns)) { for (auto &pattern : patterns)
this->patterns.push_back(pattern.get());
// Sort the patterns by benefit to simplify the matching logic. // Sort the patterns by benefit to simplify the matching logic.
std::stable_sort(this->patterns.begin(), this->patterns.end(), std::stable_sort(this->patterns.begin(), this->patterns.end(),
[](const std::unique_ptr<RewritePattern> &l, [](RewritePattern *l, RewritePattern *r) {
const std::unique_ptr<RewritePattern> &r) {
return r->getBenefit() < l->getBenefit(); return r->getBenefit() < l->getBenefit();
}); });
} }
@ -162,7 +163,7 @@ RewritePatternMatcher::RewritePatternMatcher(
/// Try to match the given operation to a pattern and rewrite it. /// Try to match the given operation to a pattern and rewrite it.
bool RewritePatternMatcher::matchAndRewrite(Operation *op, bool RewritePatternMatcher::matchAndRewrite(Operation *op,
PatternRewriter &rewriter) { PatternRewriter &rewriter) {
for (auto &pattern : patterns) { for (auto *pattern : patterns) {
// Ignore patterns that are for the wrong root or are impossible to match. // Ignore patterns that are for the wrong root or are impossible to match.
if (pattern->getRootKind() != op->getName() || if (pattern->getRootKind() != op->getName() ||
pattern->getBenefit().isImpossibleToMatch()) pattern->getBenefit().isImpossibleToMatch())

View File

@ -678,12 +678,11 @@ static void
populateLinalgToLLVMConversionPatterns(LinalgTypeConverter &converter, populateLinalgToLLVMConversionPatterns(LinalgTypeConverter &converter,
OwningRewritePatternList &patterns, OwningRewritePatternList &patterns,
MLIRContext *ctx) { MLIRContext *ctx) {
RewriteListBuilder<BufferAllocOpConversion, BufferDeallocOpConversion, patterns.insert<BufferAllocOpConversion, BufferDeallocOpConversion,
BufferSizeOpConversion, DimOpConversion, BufferSizeOpConversion, DimOpConversion,
LinalgOpConversion<DotOp>, LinalgOpConversion<MatmulOp>, LinalgOpConversion<DotOp>, LinalgOpConversion<MatmulOp>,
LoadOpConversion, RangeOpConversion, SliceOpConversion, LoadOpConversion, RangeOpConversion, SliceOpConversion,
StoreOpConversion, ViewOpConversion>::build(patterns, ctx, StoreOpConversion, ViewOpConversion>(ctx, converter);
converter);
} }
namespace { namespace {

View File

@ -60,12 +60,9 @@ void RemoveInstrumentationPass::runOnFunction() {
OwningRewritePatternList patterns; OwningRewritePatternList patterns;
auto func = getFunction(); auto func = getFunction();
auto *context = &getContext(); auto *context = &getContext();
patterns.push_back( patterns.insert<RemoveIdentityOpRewrite<StatisticsOp>,
llvm::make_unique<RemoveIdentityOpRewrite<StatisticsOp>>(context)); RemoveIdentityOpRewrite<StatisticsRefOp>,
patterns.push_back( RemoveIdentityOpRewrite<CoupledRefOp>>(context);
llvm::make_unique<RemoveIdentityOpRewrite<StatisticsRefOp>>(context));
patterns.push_back(
llvm::make_unique<RemoveIdentityOpRewrite<CoupledRefOp>>(context));
applyPatternsGreedily(func, std::move(patterns)); applyPatternsGreedily(func, std::move(patterns));
} }

View File

@ -365,8 +365,7 @@ struct SimplifyDeadAlloc : public OpRewritePattern<AllocOp> {
void AllocOp::getCanonicalizationPatterns(OwningRewritePatternList &results, void AllocOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) { MLIRContext *context) {
RewriteListBuilder<SimplifyAllocConst, SimplifyDeadAlloc>::build(results, results.insert<SimplifyAllocConst, SimplifyDeadAlloc>(context);
context);
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -544,8 +543,7 @@ static LogicalResult verify(CallIndirectOp op) {
void CallIndirectOp::getCanonicalizationPatterns( void CallIndirectOp::getCanonicalizationPatterns(
OwningRewritePatternList &results, MLIRContext *context) { OwningRewritePatternList &results, MLIRContext *context) {
results.push_back( results.insert<SimplifyIndirectCallWithKnownCallee>(context);
llvm::make_unique<SimplifyIndirectCallWithKnownCallee>(context));
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -1015,7 +1013,7 @@ static void print(OpAsmPrinter *p, CondBranchOp op) {
void CondBranchOp::getCanonicalizationPatterns( void CondBranchOp::getCanonicalizationPatterns(
OwningRewritePatternList &results, MLIRContext *context) { OwningRewritePatternList &results, MLIRContext *context) {
results.push_back(llvm::make_unique<SimplifyConstCondBranchPred>(context)); results.insert<SimplifyConstCondBranchPred>(context);
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -1231,9 +1229,8 @@ static LogicalResult verify(DeallocOp op) {
void DeallocOp::getCanonicalizationPatterns(OwningRewritePatternList &results, void DeallocOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) { MLIRContext *context) {
/// dealloc(memrefcast) -> dealloc /// dealloc(memrefcast) -> dealloc
results.push_back( results.insert<MemRefCastFolder>(getOperationName(), context);
llvm::make_unique<MemRefCastFolder>(getOperationName(), context)); results.insert<SimplifyDeadDealloc>(context);
results.push_back(llvm::make_unique<SimplifyDeadDealloc>(context));
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -1497,8 +1494,7 @@ LogicalResult DmaStartOp::verify() {
void DmaStartOp::getCanonicalizationPatterns(OwningRewritePatternList &results, void DmaStartOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) { MLIRContext *context) {
/// dma_start(memrefcast) -> dma_start /// dma_start(memrefcast) -> dma_start
results.push_back( results.insert<MemRefCastFolder>(getOperationName(), context);
llvm::make_unique<MemRefCastFolder>(getOperationName(), context));
} }
// --------------------------------------------------------------------------- // ---------------------------------------------------------------------------
@ -1561,8 +1557,7 @@ ParseResult DmaWaitOp::parse(OpAsmParser *parser, OperationState *result) {
void DmaWaitOp::getCanonicalizationPatterns(OwningRewritePatternList &results, void DmaWaitOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) { MLIRContext *context) {
/// dma_wait(memrefcast) -> dma_wait /// dma_wait(memrefcast) -> dma_wait
results.push_back( results.insert<MemRefCastFolder>(getOperationName(), context);
llvm::make_unique<MemRefCastFolder>(getOperationName(), context));
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -1695,8 +1690,7 @@ static LogicalResult verify(LoadOp op) {
void LoadOp::getCanonicalizationPatterns(OwningRewritePatternList &results, void LoadOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) { MLIRContext *context) {
/// load(memrefcast) -> load /// load(memrefcast) -> load
results.push_back( results.insert<MemRefCastFolder>(getOperationName(), context);
llvm::make_unique<MemRefCastFolder>(getOperationName(), context));
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -2007,8 +2001,7 @@ static LogicalResult verify(StoreOp op) {
void StoreOp::getCanonicalizationPatterns(OwningRewritePatternList &results, void StoreOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) { MLIRContext *context) {
/// store(memrefcast) -> store /// store(memrefcast) -> store
results.push_back( results.insert<MemRefCastFolder>(getOperationName(), context);
llvm::make_unique<MemRefCastFolder>(getOperationName(), context));
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//

View File

@ -1243,8 +1243,7 @@ struct FuncOpSignatureConversion : public ConversionPattern {
void mlir::populateFuncOpTypeConversionPattern( void mlir::populateFuncOpTypeConversionPattern(
OwningRewritePatternList &patterns, MLIRContext *ctx, OwningRewritePatternList &patterns, MLIRContext *ctx,
TypeConverter &converter) { TypeConverter &converter) {
RewriteListBuilder<FuncOpSignatureConversion>::build(patterns, ctx, patterns.insert<FuncOpSignatureConversion>(ctx, converter);
converter);
} }
/// This function converts the type signature of the given block, by invoking /// This function converts the type signature of the given block, by invoking

View File

@ -507,10 +507,11 @@ public:
void mlir::populateAffineToStdConversionPatterns( void mlir::populateAffineToStdConversionPatterns(
OwningRewritePatternList &patterns, MLIRContext *ctx) { OwningRewritePatternList &patterns, MLIRContext *ctx) {
RewriteListBuilder<AffineApplyLowering, AffineDmaStartLowering, patterns
AffineDmaWaitLowering, AffineLoadLowering, .insert<AffineApplyLowering, AffineDmaStartLowering,
AffineStoreLowering, AffineForLowering, AffineIfLowering, AffineDmaWaitLowering, AffineLoadLowering, AffineStoreLowering,
AffineTerminatorLowering>::build(patterns, ctx); AffineForLowering, AffineIfLowering, AffineTerminatorLowering>(
ctx);
} }
namespace { namespace {

View File

@ -365,12 +365,8 @@ struct LowerVectorTransfersPass
void runOnFunction() { void runOnFunction() {
OwningRewritePatternList patterns; OwningRewritePatternList patterns;
auto *context = &getContext(); auto *context = &getContext();
patterns.push_back( patterns.insert<VectorTransferRewriter<VectorTransferReadOp>,
llvm::make_unique<VectorTransferRewriter<VectorTransferReadOp>>( VectorTransferRewriter<VectorTransferWriteOp>>(context);
context));
patterns.push_back(
llvm::make_unique<VectorTransferRewriter<VectorTransferWriteOp>>(
context));
applyPatternsGreedily(getFunction(), std::move(patterns)); applyPatternsGreedily(getFunction(), std::move(patterns));
} }
}; };

View File

@ -44,8 +44,8 @@ namespace {
class GreedyPatternRewriteDriver : public PatternRewriter { class GreedyPatternRewriteDriver : public PatternRewriter {
public: public:
explicit GreedyPatternRewriteDriver(MLIRContext *ctx, explicit GreedyPatternRewriteDriver(MLIRContext *ctx,
OwningRewritePatternList &&patterns) OwningRewritePatternList &patterns)
: PatternRewriter(ctx), matcher(std::move(patterns)) { : PatternRewriter(ctx), matcher(patterns) {
worklist.reserve(64); worklist.reserve(64);
} }
@ -224,7 +224,7 @@ bool mlir::applyPatternsGreedily(Operation *op,
if (!op->isKnownIsolatedFromAbove()) if (!op->isKnownIsolatedFromAbove())
return false; return false;
GreedyPatternRewriteDriver driver(op->getContext(), std::move(patterns)); GreedyPatternRewriteDriver driver(op->getContext(), patterns);
bool converged = driver.simplify(op, maxPatternMatchIterations); bool converged = driver.simplify(op, maxPatternMatchIterations);
LLVM_DEBUG(if (!converged) { LLVM_DEBUG(if (!converged) {
llvm::dbgs() << "The pattern rewrite doesn't converge after scanning " llvm::dbgs() << "The pattern rewrite doesn't converge after scanning "

View File

@ -41,7 +41,7 @@ struct TestPatternDriver : public FunctionPass<TestPatternDriver> {
populateWithGenerated(&getContext(), &patterns); populateWithGenerated(&getContext(), &patterns);
// Verify named pattern is generated with expected name. // Verify named pattern is generated with expected name.
RewriteListBuilder<TestNamedPatternRule>::build(patterns, &getContext()); patterns.insert<TestNamedPatternRule>(&getContext());
applyPatternsGreedily(getFunction(), std::move(patterns)); applyPatternsGreedily(getFunction(), std::move(patterns));
} }
@ -193,9 +193,9 @@ struct TestLegalizePatternDriver
TestTypeConverter converter; TestTypeConverter converter;
mlir::OwningRewritePatternList patterns; mlir::OwningRewritePatternList patterns;
populateWithGenerated(&getContext(), &patterns); populateWithGenerated(&getContext(), &patterns);
RewriteListBuilder<TestRegionRewriteBlockMovement, TestRegionRewriteUndo, patterns.insert<TestRegionRewriteBlockMovement, TestRegionRewriteUndo,
TestDropOp, TestPassthroughInvalidOp, TestDropOp, TestPassthroughInvalidOp, TestSplitReturnType>(
TestSplitReturnType>::build(patterns, &getContext()); &getContext());
mlir::populateFuncOpTypeConversionPattern(patterns, &getContext(), mlir::populateFuncOpTypeConversionPattern(patterns, &getContext(),
converter); converter);

View File

@ -133,7 +133,7 @@ static LogicalResult runMLIRPasses(ModuleOp m) {
pm.addPass(createConvertToLLVMIRPass([](LLVMTypeConverter &converter, pm.addPass(createConvertToLLVMIRPass([](LLVMTypeConverter &converter,
OwningRewritePatternList &patterns) { OwningRewritePatternList &patterns) {
populateStdToLLVMConversionPatterns(converter, patterns); populateStdToLLVMConversionPatterns(converter, patterns);
patterns.push_back(llvm::make_unique<GPULaunchFuncOpLowering>(converter)); patterns.insert<GPULaunchFuncOpLowering>(converter);
})); }));
pm.addPass(createLowerGpuOpsToNVVMOpsPass()); pm.addPass(createLowerGpuOpsToNVVMOpsPass());
pm.addPass(createConvertGPUKernelToCubinPass(&compilePtxToCubin)); pm.addPass(createConvertGPUKernelToCubinPass(&compilePtxToCubin));

View File

@ -935,8 +935,7 @@ static void emitRewriters(const RecordKeeper &recordKeeper, raw_ostream &os) {
os << "void populateWithGenerated(MLIRContext *context, " os << "void populateWithGenerated(MLIRContext *context, "
<< "OwningRewritePatternList *patterns) {\n"; << "OwningRewritePatternList *patterns) {\n";
for (const auto &name : rewriterNames) { for (const auto &name : rewriterNames) {
os << " patterns->push_back(llvm::make_unique<" << name os << " patterns->insert<" << name << ">(context);\n";
<< ">(context));\n";
} }
os << "}\n"; os << "}\n";
} }