diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc b/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc index e1d89cd2807..9e8cf6c4ce8 100644 --- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc +++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc @@ -506,7 +506,7 @@ OpFoldResult ReshapeOp::fold(ArrayRef operands) { void ReshapeOp::getCanonicalizationPatterns(OwningRewritePatternList &results, MLIRContext *context) { - results.push_back(llvm::make_unique(context)); + results.insert(context); } //===----------------------------------------------------------------------===// @@ -591,7 +591,7 @@ struct DropFakeQuant : public RewritePattern { void FakeQuantOp::getCanonicalizationPatterns(OwningRewritePatternList &results, MLIRContext *context) { - results.push_back(llvm::make_unique(context)); + results.insert(context); } //===----------------------------------------------------------------------===// diff --git a/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc index faf80f3acb8..0c8176e96f6 100644 --- a/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc +++ b/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc @@ -205,9 +205,9 @@ void LegalizeTF::runOnFunction() { // Add the generated patterns to the list. populateWithGenerated(ctx, &patterns); - RewriteListBuilder::build(patterns, ctx); + patterns.insert(ctx); applyPatternsGreedily(func, std::move(patterns)); } diff --git a/tensorflow/compiler/mlir/lite/transforms/optimize.cc b/tensorflow/compiler/mlir/lite/transforms/optimize.cc index da910be0e6e..d93c01a806c 100644 --- a/tensorflow/compiler/mlir/lite/transforms/optimize.cc +++ b/tensorflow/compiler/mlir/lite/transforms/optimize.cc @@ -240,8 +240,8 @@ void Optimize::runOnFunction() { auto func = getFunction(); // Add the generated patterns to the list. TFL::populateWithGenerated(ctx, &patterns); - RewriteListBuilder::build(patterns, ctx); + patterns.insert(ctx); applyPatternsGreedily(func, std::move(patterns)); } diff --git a/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc b/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc index 166b732f936..252381d072e 100644 --- a/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc +++ b/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc @@ -381,8 +381,7 @@ void PrepareTFPass::runOnFunction() { // parameters from the TF Quant ops, thus this pattern should run with the // first `applyPatternsGreedily` method, which would otherwise removes the // TF FakeQuant ops by the constant folding. - patterns.push_back( - llvm::make_unique(&getContext())); + patterns.insert(&getContext()); TFL::populateWithGenerated(&getContext(), &patterns); // TODO(karimnosseir): Split to separate pass probably after // deciding on long term plan for this optimization. @@ -394,9 +393,8 @@ void PrepareTFPass::runOnFunction() { // Load the generated pattern again, so new quantization pass-through // will be applied. TFL::populateWithGenerated(&getContext(), &patterns); - patterns.push_back(llvm::make_unique(&getContext())); - patterns.push_back( - llvm::make_unique(&getContext())); + patterns.insert( + &getContext()); applyPatternsGreedily(func, std::move(patterns)); } diff --git a/tensorflow/compiler/mlir/lite/transforms/quantize.cc b/tensorflow/compiler/mlir/lite/transforms/quantize.cc index 91bb26a976b..78abdd476ed 100644 --- a/tensorflow/compiler/mlir/lite/transforms/quantize.cc +++ b/tensorflow/compiler/mlir/lite/transforms/quantize.cc @@ -55,8 +55,8 @@ void QuantizePass::runOnFunction() { auto func = getFunction(); auto* ctx = func.getContext(); TFL::populateWithGenerated(ctx, &patterns); - mlir::RewriteListBuilder>::build(patterns, ctx); + patterns.insert>(ctx); applyPatternsGreedily(func, std::move(patterns)); } } // namespace diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc index 1d6abe6c848..83de56bc0a4 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc @@ -106,7 +106,7 @@ namespace { void AddOp::getCanonicalizationPatterns(OwningRewritePatternList &results, MLIRContext *context) { - RewriteListBuilder::build(results, context); + results.insert(context); } //===----------------------------------------------------------------------===// @@ -115,7 +115,7 @@ void AddOp::getCanonicalizationPatterns(OwningRewritePatternList &results, void AddV2Op::getCanonicalizationPatterns(OwningRewritePatternList &results, MLIRContext *context) { - RewriteListBuilder::build(results, context); + results.insert(context); } //===----------------------------------------------------------------------===// @@ -144,7 +144,7 @@ struct AssertWithTrue : public OpRewritePattern { void AssertOp::getCanonicalizationPatterns(OwningRewritePatternList &results, MLIRContext *context) { - RewriteListBuilder::build(results, context); + results.insert(context); } //===----------------------------------------------------------------------===// @@ -153,7 +153,7 @@ void AssertOp::getCanonicalizationPatterns(OwningRewritePatternList &results, void BitcastOp::getCanonicalizationPatterns(OwningRewritePatternList &results, MLIRContext *context) { - RewriteListBuilder::build(results, context); + results.insert(context); } //===----------------------------------------------------------------------===// @@ -174,7 +174,7 @@ static LogicalResult Verify(BroadcastToOp op) { void CastOp::getCanonicalizationPatterns(OwningRewritePatternList &results, MLIRContext *context) { - RewriteListBuilder::build(results, context); + results.insert(context); } //===----------------------------------------------------------------------===// @@ -183,7 +183,7 @@ void CastOp::getCanonicalizationPatterns(OwningRewritePatternList &results, void ConjOp::getCanonicalizationPatterns(OwningRewritePatternList &results, MLIRContext *context) { - RewriteListBuilder::build(results, context); + results.insert(context); } //===----------------------------------------------------------------------===// @@ -239,7 +239,7 @@ void ConstOp::build(Builder *builder, OperationState *result, Type type, void DivOp::getCanonicalizationPatterns(OwningRewritePatternList &results, MLIRContext *context) { - RewriteListBuilder::build(results, context); + results.insert(context); } //===----------------------------------------------------------------------===// @@ -415,7 +415,7 @@ static LogicalResult Verify(IfOp op) { void InvertOp::getCanonicalizationPatterns(OwningRewritePatternList &results, MLIRContext *context) { - RewriteListBuilder::build(results, context); + results.insert(context); } //===----------------------------------------------------------------------===// @@ -449,7 +449,7 @@ OpFoldResult LeakyReluOp::fold(ArrayRef operands) { void LogOp::getCanonicalizationPatterns(OwningRewritePatternList &results, MLIRContext *context) { - RewriteListBuilder::build(results, context); + results.insert(context); } //===----------------------------------------------------------------------===// @@ -458,10 +458,9 @@ void LogOp::getCanonicalizationPatterns(OwningRewritePatternList &results, void LogicalNotOp::getCanonicalizationPatterns( OwningRewritePatternList &results, MLIRContext *context) { - RewriteListBuilder::build(results, - context); + results.insert(context); } //===----------------------------------------------------------------------===// @@ -470,7 +469,7 @@ void LogicalNotOp::getCanonicalizationPatterns( void NegOp::getCanonicalizationPatterns(OwningRewritePatternList &results, MLIRContext *context) { - RewriteListBuilder::build(results, context); + results.insert(context); } //===----------------------------------------------------------------------===// @@ -479,7 +478,7 @@ void NegOp::getCanonicalizationPatterns(OwningRewritePatternList &results, void ReciprocalOp::getCanonicalizationPatterns( OwningRewritePatternList &results, MLIRContext *context) { - RewriteListBuilder::build(results, context); + results.insert(context); } //===----------------------------------------------------------------------===// @@ -538,7 +537,7 @@ void RankOp::build(Builder *builder, OperationState *result, Value *input) { void RealDivOp::getCanonicalizationPatterns(OwningRewritePatternList &results, MLIRContext *context) { - RewriteListBuilder::build(results, context); + results.insert(context); } //===----------------------------------------------------------------------===// @@ -698,7 +697,7 @@ static LogicalResult Verify(SoftmaxOp op) { void SquareOp::getCanonicalizationPatterns(OwningRewritePatternList &results, MLIRContext *context) { - RewriteListBuilder::build(results, context); + results.insert(context); } //===----------------------------------------------------------------------===// @@ -707,7 +706,7 @@ void SquareOp::getCanonicalizationPatterns(OwningRewritePatternList &results, void SubOp::getCanonicalizationPatterns(OwningRewritePatternList &results, MLIRContext *context) { - RewriteListBuilder::build(results, context); + results.insert(context); } //===----------------------------------------------------------------------===// @@ -778,7 +777,7 @@ void TransposeOp::build(Builder *builder, OperationState *result, Value *x, void TruncateDivOp::getCanonicalizationPatterns( OwningRewritePatternList &results, MLIRContext *context) { - RewriteListBuilder::build(results, context); + results.insert(context); } //===----------------------------------------------------------------------===// @@ -867,7 +866,7 @@ static LogicalResult Verify(WhileOp op) { void XdivyOp::getCanonicalizationPatterns(OwningRewritePatternList &results, MLIRContext *context) { - RewriteListBuilder::build(results, context); + results.insert(context); } //===----------------------------------------------------------------------===// diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_to_standard.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_to_standard.cc index 623121d0c72..b62bb7ef4ec 100644 --- a/tensorflow/compiler/mlir/xla/transforms/legalize_to_standard.cc +++ b/tensorflow/compiler/mlir/xla/transforms/legalize_to_standard.cc @@ -133,10 +133,8 @@ void LegalizeToStandard::runOnFunction() { auto func = getFunction(); mlir::XLA::populateWithGenerated(func.getContext(), &patterns); - patterns.push_back( - llvm::make_unique(&getContext())); - patterns.push_back( - llvm::make_unique(&getContext())); + patterns.insert( + &getContext()); applyPatternsGreedily(func, std::move(patterns)); } diff --git a/third_party/mlir/include/mlir/Conversion/ControlFlowToCFG/ConvertControlFlowToCFG.h b/third_party/mlir/include/mlir/Conversion/ControlFlowToCFG/ConvertControlFlowToCFG.h index e8ab2732d31..78e4356607f 100644 --- a/third_party/mlir/include/mlir/Conversion/ControlFlowToCFG/ConvertControlFlowToCFG.h +++ b/third_party/mlir/include/mlir/Conversion/ControlFlowToCFG/ConvertControlFlowToCFG.h @@ -29,7 +29,7 @@ class MLIRContext; class RewritePattern; // Owning list of rewriting patterns. -using OwningRewritePatternList = std::vector>; +class OwningRewritePatternList; /// Collect a set of patterns to lower from loop.for, loop.if, and /// loop.terminator to CFG operations within the Standard dialect, in particular diff --git a/third_party/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h b/third_party/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h index 361294a729e..941e382905f 100644 --- a/third_party/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h +++ b/third_party/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h @@ -38,7 +38,7 @@ class RewritePattern; class Type; // Owning list of rewriting patterns. -using OwningRewritePatternList = std::vector>; +class OwningRewritePatternList; /// Type for a callback constructing the owning list of patterns for the /// conversion to the LLVMIR dialect. The callback is expected to append diff --git a/third_party/mlir/include/mlir/IR/OperationSupport.h b/third_party/mlir/include/mlir/IR/OperationSupport.h index c76f1d620af..204da29b39a 100644 --- a/third_party/mlir/include/mlir/IR/OperationSupport.h +++ b/third_party/mlir/include/mlir/IR/OperationSupport.h @@ -57,9 +57,7 @@ class Value; /// either OpTy or OperandAdaptor seamlessly. template using OperandAdaptor = typename OpTy::OperandAdaptor; -/// This is a vector that owns the patterns inside of it. -using OwningPatternList = std::vector>; -using OwningRewritePatternList = std::vector>; +class OwningRewritePatternList; enum class OperationProperty { /// This bit is set for an operation if it is a commutative operation: that diff --git a/third_party/mlir/include/mlir/IR/PatternMatch.h b/third_party/mlir/include/mlir/IR/PatternMatch.h index d739a804438..e3897b1d63a 100644 --- a/third_party/mlir/include/mlir/IR/PatternMatch.h +++ b/third_party/mlir/include/mlir/IR/PatternMatch.h @@ -394,8 +394,39 @@ private: // Pattern-driven rewriters //===----------------------------------------------------------------------===// -/// This is a vector that owns the patterns inside of it. -using OwningRewritePatternList = std::vector>; +class OwningRewritePatternList { + using PatternListT = std::vector>; + +public: + PatternListT::iterator begin() { return patterns.begin(); } + PatternListT::iterator end() { return patterns.end(); } + PatternListT::const_iterator begin() const { return patterns.begin(); } + PatternListT::const_iterator end() const { return patterns.end(); } + + //===--------------------------------------------------------------------===// + // Pattern Insertion + //===--------------------------------------------------------------------===// + + void insert(RewritePattern *pattern) { patterns.emplace_back(pattern); } + + /// Add an instance of each of the pattern types 'Ts' to the pattern list with + /// the given arguments. + // Note: ConstructorArg is necessary here to separate the two variadic lists. + template + void insert(ConstructorArg &&arg, ConstructorArgs &&... args) { + // The following expands a call to emplace_back for each of the pattern + // types 'Ts'. This magic is necessary due to a limitation in the places + // that a parameter pack can be expanded in c++11. + // FIXME: In c++17 this can be simplified by using 'fold expressions'. + using dummy = int[]; + (void)dummy{ + 0, (patterns.emplace_back(llvm::make_unique(arg, args...)), 0)...}; + } + +private: + PatternListT patterns; +}; /// This class manages optimization and execution of a group of rewrite /// patterns, providing an API for finding and applying, the best match against @@ -404,7 +435,7 @@ using OwningRewritePatternList = std::vector>; class RewritePatternMatcher { public: /// Create a RewritePatternMatcher with the specified set of patterns. - explicit RewritePatternMatcher(OwningRewritePatternList &&patterns); + explicit RewritePatternMatcher(OwningRewritePatternList &patterns); /// Try to match the given operation to a pattern and rewrite it. Return /// true if any pattern matches. @@ -416,7 +447,7 @@ private: /// The group of patterns that are matched for optimization through this /// matcher. - OwningRewritePatternList patterns; + std::vector patterns; }; /// Rewrite the regions of the specified operation, which must be isolated from @@ -427,29 +458,6 @@ private: /// bool applyPatternsGreedily(Operation *op, OwningRewritePatternList &&patterns); -/// Helper class to create a list of rewrite patterns given a list of their -/// types and a list of attributes perfect-forwarded to each of the conversion -/// constructors. -template struct RewriteListBuilder { - template - static void build(OwningRewritePatternList &patterns, - ConstructorArgs &&... constructorArgs) { - RewriteListBuilder::build( - patterns, std::forward(constructorArgs)...); - RewriteListBuilder::build( - patterns, std::forward(constructorArgs)...); - } -}; - -// Template specialization to stop recursion. -template struct RewriteListBuilder { - template - static void build(OwningRewritePatternList &patterns, - ConstructorArgs &&... constructorArgs) { - patterns.emplace_back(llvm::make_unique( - std::forward(constructorArgs)...)); - } -}; } // end namespace mlir #endif // MLIR_PATTERN_MATCH_H diff --git a/third_party/mlir/include/mlir/Transforms/LowerAffine.h b/third_party/mlir/include/mlir/Transforms/LowerAffine.h index 9ad3f66def5..5fae4763bf7 100644 --- a/third_party/mlir/include/mlir/Transforms/LowerAffine.h +++ b/third_party/mlir/include/mlir/Transforms/LowerAffine.h @@ -32,7 +32,7 @@ class RewritePattern; class Value; // Owning list of rewriting patterns. -using OwningRewritePatternList = std::vector>; +class OwningRewritePatternList; /// Emit code that computes the given affine expression using standard /// arithmetic operations applied to the provided dimension and symbol values. diff --git a/third_party/mlir/lib/AffineOps/AffineOps.cpp b/third_party/mlir/lib/AffineOps/AffineOps.cpp index 9a026231ab2..767c2e344d9 100644 --- a/third_party/mlir/lib/AffineOps/AffineOps.cpp +++ b/third_party/mlir/lib/AffineOps/AffineOps.cpp @@ -708,7 +708,7 @@ struct SimplifyAffineApply : public OpRewritePattern { void AffineApplyOp::getCanonicalizationPatterns( OwningRewritePatternList &results, MLIRContext *context) { - results.push_back(llvm::make_unique(context)); + results.insert(context); } //===----------------------------------------------------------------------===// @@ -912,8 +912,7 @@ LogicalResult AffineDmaStartOp::verify() { void AffineDmaStartOp::getCanonicalizationPatterns( OwningRewritePatternList &results, MLIRContext *context) { /// dma_start(memrefcast) -> dma_start - results.push_back( - llvm::make_unique(getOperationName(), context)); + results.insert(getOperationName(), context); } //===----------------------------------------------------------------------===// @@ -989,8 +988,7 @@ LogicalResult AffineDmaWaitOp::verify() { void AffineDmaWaitOp::getCanonicalizationPatterns( OwningRewritePatternList &results, MLIRContext *context) { /// dma_wait(memrefcast) -> dma_wait - results.push_back( - llvm::make_unique(getOperationName(), context)); + results.insert(getOperationName(), context); } //===----------------------------------------------------------------------===// @@ -1333,7 +1331,7 @@ struct AffineForLoopBoundFolder : public OpRewritePattern { void AffineForOp::getCanonicalizationPatterns(OwningRewritePatternList &results, MLIRContext *context) { - results.push_back(llvm::make_unique(context)); + results.insert(context); } AffineBound AffineForOp::getLowerBound() { @@ -1659,8 +1657,7 @@ LogicalResult AffineLoadOp::verify() { void AffineLoadOp::getCanonicalizationPatterns( OwningRewritePatternList &results, MLIRContext *context) { /// load(memrefcast) -> load - results.push_back( - llvm::make_unique(getOperationName(), context)); + results.insert(getOperationName(), context); } //===----------------------------------------------------------------------===// @@ -1752,8 +1749,7 @@ LogicalResult AffineStoreOp::verify() { void AffineStoreOp::getCanonicalizationPatterns( OwningRewritePatternList &results, MLIRContext *context) { /// load(memrefcast) -> load - results.push_back( - llvm::make_unique(getOperationName(), context)); + results.insert(getOperationName(), context); } #define GET_OP_CLASSES diff --git a/third_party/mlir/lib/Conversion/ControlFlowToCFG/ConvertControlFlowToCFG.cpp b/third_party/mlir/lib/Conversion/ControlFlowToCFG/ConvertControlFlowToCFG.cpp index c37decf69e6..034aa22f922 100644 --- a/third_party/mlir/lib/Conversion/ControlFlowToCFG/ConvertControlFlowToCFG.cpp +++ b/third_party/mlir/lib/Conversion/ControlFlowToCFG/ConvertControlFlowToCFG.cpp @@ -258,8 +258,7 @@ IfLowering::matchAndRewrite(IfOp ifOp, PatternRewriter &rewriter) const { void mlir::populateLoopToStdConversionPatterns( OwningRewritePatternList &patterns, MLIRContext *ctx) { - RewriteListBuilder::build( - patterns, ctx); + patterns.insert(ctx); } void ControlFlowToCFGPass::runOnFunction() { diff --git a/third_party/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp b/third_party/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp index 4eadb874908..58f01fc6689 100644 --- a/third_party/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp +++ b/third_party/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp @@ -104,8 +104,7 @@ void GPUToSPIRVPass::runOnModule() { SPIRVTypeConverter typeConverter(context); SPIRVEntryFnTypeConverter entryFnConverter(context); OwningRewritePatternList patterns; - RewriteListBuilder::build( - patterns, context, typeConverter, entryFnConverter); + patterns.insert(context, typeConverter, entryFnConverter); populateStandardToSPIRVPatterns(context, patterns); ConversionTarget target(*context); diff --git a/third_party/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp b/third_party/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp index af8812c8cf4..09ddcd1e475 100644 --- a/third_party/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp +++ b/third_party/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp @@ -1023,7 +1023,7 @@ void mlir::LLVM::ensureDistinctSuccessors(ModuleOp m) { void mlir::populateStdToLLVMConversionPatterns( LLVMTypeConverter &converter, OwningRewritePatternList &patterns) { // FIXME: this should be tablegen'ed - RewriteListBuilder< + patterns.insert< AddFOpLowering, AddIOpLowering, AndOpLowering, AllocOpLowering, BranchOpLowering, CallIndirectOpLowering, CallOpLowering, CmpIOpLowering, CondBranchOpLowering, ConstLLVMOpLowering, DeallocOpLowering, @@ -1032,8 +1032,7 @@ void mlir::populateStdToLLVMConversionPatterns( MemRefCastOpLowering, MulFOpLowering, MulIOpLowering, OrOpLowering, RemISOpLowering, RemIUOpLowering, RemFOpLowering, ReturnOpLowering, SelectOpLowering, SIToFPLowering, StoreOpLowering, SubFOpLowering, - SubIOpLowering, XOrOpLowering>::build(patterns, *converter.getDialect(), - converter); + SubIOpLowering, XOrOpLowering>(*converter.getDialect(), converter); } // Convert types using the stored LLVM IR module. diff --git a/third_party/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp b/third_party/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp index d32d8668046..067f2aeda06 100644 --- a/third_party/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp +++ b/third_party/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp @@ -201,6 +201,6 @@ void populateStandardToSPIRVPatterns(MLIRContext *context, OwningRewritePatternList &patterns) { populateWithGenerated(context, &patterns); // Add the return op conversion. - RewriteListBuilder::build(patterns, context); + patterns.insert(context); } } // namespace mlir diff --git a/third_party/mlir/lib/Dialect/FxpMathOps/Transforms/LowerUniformRealMath.cpp b/third_party/mlir/lib/Dialect/FxpMathOps/Transforms/LowerUniformRealMath.cpp index dafc8e711f5..d2f3881710c 100644 --- a/third_party/mlir/lib/Dialect/FxpMathOps/Transforms/LowerUniformRealMath.cpp +++ b/third_party/mlir/lib/Dialect/FxpMathOps/Transforms/LowerUniformRealMath.cpp @@ -368,8 +368,7 @@ void LowerUniformRealMathPass::runOnFunction() { auto fn = getFunction(); OwningRewritePatternList patterns; auto *context = &getContext(); - patterns.push_back(llvm::make_unique(context)); - patterns.push_back(llvm::make_unique(context)); + patterns.insert(context); applyPatternsGreedily(fn, std::move(patterns)); } @@ -389,7 +388,7 @@ void LowerUniformCastsPass::runOnFunction() { auto fn = getFunction(); OwningRewritePatternList patterns; auto *context = &getContext(); - patterns.push_back(llvm::make_unique(context)); + patterns.insert(context); applyPatternsGreedily(fn, std::move(patterns)); } diff --git a/third_party/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/third_party/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp index bda5979939c..2fbaa49f56e 100644 --- a/third_party/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp +++ b/third_party/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp @@ -372,7 +372,7 @@ class PropagateConstantBounds : public OpRewritePattern { void LaunchOp::getCanonicalizationPatterns(OwningRewritePatternList &results, MLIRContext *context) { - RewriteListBuilder::build(results, context); + results.insert(context); } //===----------------------------------------------------------------------===// diff --git a/third_party/mlir/lib/Dialect/QuantOps/IR/QuantOps.cpp b/third_party/mlir/lib/Dialect/QuantOps/IR/QuantOps.cpp index e237e8b6eb2..3bd49d43adc 100644 --- a/third_party/mlir/lib/Dialect/QuantOps/IR/QuantOps.cpp +++ b/third_party/mlir/lib/Dialect/QuantOps/IR/QuantOps.cpp @@ -60,8 +60,7 @@ public: void StorageCastOp::getCanonicalizationPatterns( OwningRewritePatternList &patterns, MLIRContext *context) { - patterns.push_back( - llvm::make_unique(context)); + patterns.insert(context); } QuantizationDialect::QuantizationDialect(MLIRContext *context) diff --git a/third_party/mlir/lib/Dialect/QuantOps/Transforms/ConvertConst.cpp b/third_party/mlir/lib/Dialect/QuantOps/Transforms/ConvertConst.cpp index 8469fa2ea70..2276fbd21c9 100644 --- a/third_party/mlir/lib/Dialect/QuantOps/Transforms/ConvertConst.cpp +++ b/third_party/mlir/lib/Dialect/QuantOps/Transforms/ConvertConst.cpp @@ -108,7 +108,7 @@ void ConvertConstPass::runOnFunction() { OwningRewritePatternList patterns; auto func = getFunction(); auto *context = &getContext(); - patterns.push_back(llvm::make_unique(context)); + patterns.insert(context); applyPatternsGreedily(func, std::move(patterns)); } diff --git a/third_party/mlir/lib/Dialect/QuantOps/Transforms/ConvertSimQuant.cpp b/third_party/mlir/lib/Dialect/QuantOps/Transforms/ConvertSimQuant.cpp index 32d8c8a81c1..8f5d1b33c64 100644 --- a/third_party/mlir/lib/Dialect/QuantOps/Transforms/ConvertSimQuant.cpp +++ b/third_party/mlir/lib/Dialect/QuantOps/Transforms/ConvertSimQuant.cpp @@ -97,8 +97,7 @@ void ConvertSimulatedQuantPass::runOnFunction() { OwningRewritePatternList patterns; auto func = getFunction(); auto *context = &getContext(); - patterns.push_back( - llvm::make_unique(context, &hadFailure)); + patterns.insert(context, &hadFailure); applyPatternsGreedily(func, std::move(patterns)); if (hadFailure) signalPassFailure(); diff --git a/third_party/mlir/lib/IR/PatternMatch.cpp b/third_party/mlir/lib/IR/PatternMatch.cpp index 5010b845c78..94fa7ab43f7 100644 --- a/third_party/mlir/lib/IR/PatternMatch.cpp +++ b/third_party/mlir/lib/IR/PatternMatch.cpp @@ -149,12 +149,13 @@ void PatternRewriter::updatedRootInPlace( //===----------------------------------------------------------------------===// RewritePatternMatcher::RewritePatternMatcher( - OwningRewritePatternList &&patterns) - : patterns(std::move(patterns)) { + OwningRewritePatternList &patterns) { + for (auto &pattern : patterns) + this->patterns.push_back(pattern.get()); + // Sort the patterns by benefit to simplify the matching logic. std::stable_sort(this->patterns.begin(), this->patterns.end(), - [](const std::unique_ptr &l, - const std::unique_ptr &r) { + [](RewritePattern *l, RewritePattern *r) { return r->getBenefit() < l->getBenefit(); }); } @@ -162,7 +163,7 @@ RewritePatternMatcher::RewritePatternMatcher( /// Try to match the given operation to a pattern and rewrite it. bool RewritePatternMatcher::matchAndRewrite(Operation *op, PatternRewriter &rewriter) { - for (auto &pattern : patterns) { + for (auto *pattern : patterns) { // Ignore patterns that are for the wrong root or are impossible to match. if (pattern->getRootKind() != op->getName() || pattern->getBenefit().isImpossibleToMatch()) diff --git a/third_party/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp b/third_party/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp index 6b62a8e1340..7c2ea5945f4 100644 --- a/third_party/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp +++ b/third_party/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp @@ -678,12 +678,11 @@ static void populateLinalgToLLVMConversionPatterns(LinalgTypeConverter &converter, OwningRewritePatternList &patterns, MLIRContext *ctx) { - RewriteListBuilder, LinalgOpConversion, - LoadOpConversion, RangeOpConversion, SliceOpConversion, - StoreOpConversion, ViewOpConversion>::build(patterns, ctx, - converter); + patterns.insert, LinalgOpConversion, + LoadOpConversion, RangeOpConversion, SliceOpConversion, + StoreOpConversion, ViewOpConversion>(ctx, converter); } namespace { diff --git a/third_party/mlir/lib/Quantizer/Transforms/RemoveInstrumentationPass.cpp b/third_party/mlir/lib/Quantizer/Transforms/RemoveInstrumentationPass.cpp index 6b376db8516..3de89137c3c 100644 --- a/third_party/mlir/lib/Quantizer/Transforms/RemoveInstrumentationPass.cpp +++ b/third_party/mlir/lib/Quantizer/Transforms/RemoveInstrumentationPass.cpp @@ -60,12 +60,9 @@ void RemoveInstrumentationPass::runOnFunction() { OwningRewritePatternList patterns; auto func = getFunction(); auto *context = &getContext(); - patterns.push_back( - llvm::make_unique>(context)); - patterns.push_back( - llvm::make_unique>(context)); - patterns.push_back( - llvm::make_unique>(context)); + patterns.insert, + RemoveIdentityOpRewrite, + RemoveIdentityOpRewrite>(context); applyPatternsGreedily(func, std::move(patterns)); } diff --git a/third_party/mlir/lib/StandardOps/Ops.cpp b/third_party/mlir/lib/StandardOps/Ops.cpp index df99f00c110..9ecd99a5169 100644 --- a/third_party/mlir/lib/StandardOps/Ops.cpp +++ b/third_party/mlir/lib/StandardOps/Ops.cpp @@ -365,8 +365,7 @@ struct SimplifyDeadAlloc : public OpRewritePattern { void AllocOp::getCanonicalizationPatterns(OwningRewritePatternList &results, MLIRContext *context) { - RewriteListBuilder::build(results, - context); + results.insert(context); } //===----------------------------------------------------------------------===// @@ -544,8 +543,7 @@ static LogicalResult verify(CallIndirectOp op) { void CallIndirectOp::getCanonicalizationPatterns( OwningRewritePatternList &results, MLIRContext *context) { - results.push_back( - llvm::make_unique(context)); + results.insert(context); } //===----------------------------------------------------------------------===// @@ -1015,7 +1013,7 @@ static void print(OpAsmPrinter *p, CondBranchOp op) { void CondBranchOp::getCanonicalizationPatterns( OwningRewritePatternList &results, MLIRContext *context) { - results.push_back(llvm::make_unique(context)); + results.insert(context); } //===----------------------------------------------------------------------===// @@ -1231,9 +1229,8 @@ static LogicalResult verify(DeallocOp op) { void DeallocOp::getCanonicalizationPatterns(OwningRewritePatternList &results, MLIRContext *context) { /// dealloc(memrefcast) -> dealloc - results.push_back( - llvm::make_unique(getOperationName(), context)); - results.push_back(llvm::make_unique(context)); + results.insert(getOperationName(), context); + results.insert(context); } //===----------------------------------------------------------------------===// @@ -1497,8 +1494,7 @@ LogicalResult DmaStartOp::verify() { void DmaStartOp::getCanonicalizationPatterns(OwningRewritePatternList &results, MLIRContext *context) { /// dma_start(memrefcast) -> dma_start - results.push_back( - llvm::make_unique(getOperationName(), context)); + results.insert(getOperationName(), context); } // --------------------------------------------------------------------------- @@ -1561,8 +1557,7 @@ ParseResult DmaWaitOp::parse(OpAsmParser *parser, OperationState *result) { void DmaWaitOp::getCanonicalizationPatterns(OwningRewritePatternList &results, MLIRContext *context) { /// dma_wait(memrefcast) -> dma_wait - results.push_back( - llvm::make_unique(getOperationName(), context)); + results.insert(getOperationName(), context); } //===----------------------------------------------------------------------===// @@ -1695,8 +1690,7 @@ static LogicalResult verify(LoadOp op) { void LoadOp::getCanonicalizationPatterns(OwningRewritePatternList &results, MLIRContext *context) { /// load(memrefcast) -> load - results.push_back( - llvm::make_unique(getOperationName(), context)); + results.insert(getOperationName(), context); } //===----------------------------------------------------------------------===// @@ -2007,8 +2001,7 @@ static LogicalResult verify(StoreOp op) { void StoreOp::getCanonicalizationPatterns(OwningRewritePatternList &results, MLIRContext *context) { /// store(memrefcast) -> store - results.push_back( - llvm::make_unique(getOperationName(), context)); + results.insert(getOperationName(), context); } //===----------------------------------------------------------------------===// diff --git a/third_party/mlir/lib/Transforms/DialectConversion.cpp b/third_party/mlir/lib/Transforms/DialectConversion.cpp index 50c636f708e..6f264b0af35 100644 --- a/third_party/mlir/lib/Transforms/DialectConversion.cpp +++ b/third_party/mlir/lib/Transforms/DialectConversion.cpp @@ -1243,8 +1243,7 @@ struct FuncOpSignatureConversion : public ConversionPattern { void mlir::populateFuncOpTypeConversionPattern( OwningRewritePatternList &patterns, MLIRContext *ctx, TypeConverter &converter) { - RewriteListBuilder::build(patterns, ctx, - converter); + patterns.insert(ctx, converter); } /// This function converts the type signature of the given block, by invoking diff --git a/third_party/mlir/lib/Transforms/LowerAffine.cpp b/third_party/mlir/lib/Transforms/LowerAffine.cpp index f35f963b8ae..1c558efd8e4 100644 --- a/third_party/mlir/lib/Transforms/LowerAffine.cpp +++ b/third_party/mlir/lib/Transforms/LowerAffine.cpp @@ -507,10 +507,11 @@ public: void mlir::populateAffineToStdConversionPatterns( OwningRewritePatternList &patterns, MLIRContext *ctx) { - RewriteListBuilder::build(patterns, ctx); + patterns + .insert( + ctx); } namespace { diff --git a/third_party/mlir/lib/Transforms/LowerVectorTransfers.cpp b/third_party/mlir/lib/Transforms/LowerVectorTransfers.cpp index 3585e2befd6..ef67488023f 100644 --- a/third_party/mlir/lib/Transforms/LowerVectorTransfers.cpp +++ b/third_party/mlir/lib/Transforms/LowerVectorTransfers.cpp @@ -365,12 +365,8 @@ struct LowerVectorTransfersPass void runOnFunction() { OwningRewritePatternList patterns; auto *context = &getContext(); - patterns.push_back( - llvm::make_unique>( - context)); - patterns.push_back( - llvm::make_unique>( - context)); + patterns.insert, + VectorTransferRewriter>(context); applyPatternsGreedily(getFunction(), std::move(patterns)); } }; diff --git a/third_party/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/third_party/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp index 52952178b37..1df4ceec8f3 100644 --- a/third_party/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp +++ b/third_party/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp @@ -44,8 +44,8 @@ namespace { class GreedyPatternRewriteDriver : public PatternRewriter { public: explicit GreedyPatternRewriteDriver(MLIRContext *ctx, - OwningRewritePatternList &&patterns) - : PatternRewriter(ctx), matcher(std::move(patterns)) { + OwningRewritePatternList &patterns) + : PatternRewriter(ctx), matcher(patterns) { worklist.reserve(64); } @@ -224,7 +224,7 @@ bool mlir::applyPatternsGreedily(Operation *op, if (!op->isKnownIsolatedFromAbove()) return false; - GreedyPatternRewriteDriver driver(op->getContext(), std::move(patterns)); + GreedyPatternRewriteDriver driver(op->getContext(), patterns); bool converged = driver.simplify(op, maxPatternMatchIterations); LLVM_DEBUG(if (!converged) { llvm::dbgs() << "The pattern rewrite doesn't converge after scanning " diff --git a/third_party/mlir/test/lib/TestDialect/TestPatterns.cpp b/third_party/mlir/test/lib/TestDialect/TestPatterns.cpp index 201dfc3005c..ed94eed4fdd 100644 --- a/third_party/mlir/test/lib/TestDialect/TestPatterns.cpp +++ b/third_party/mlir/test/lib/TestDialect/TestPatterns.cpp @@ -41,7 +41,7 @@ struct TestPatternDriver : public FunctionPass { populateWithGenerated(&getContext(), &patterns); // Verify named pattern is generated with expected name. - RewriteListBuilder::build(patterns, &getContext()); + patterns.insert(&getContext()); applyPatternsGreedily(getFunction(), std::move(patterns)); } @@ -193,9 +193,9 @@ struct TestLegalizePatternDriver TestTypeConverter converter; mlir::OwningRewritePatternList patterns; populateWithGenerated(&getContext(), &patterns); - RewriteListBuilder::build(patterns, &getContext()); + patterns.insert( + &getContext()); mlir::populateFuncOpTypeConversionPattern(patterns, &getContext(), converter); diff --git a/third_party/mlir/tools/mlir-cuda-runner/mlir-cuda-runner.cpp b/third_party/mlir/tools/mlir-cuda-runner/mlir-cuda-runner.cpp index edf6aeae469..f75413fdaed 100644 --- a/third_party/mlir/tools/mlir-cuda-runner/mlir-cuda-runner.cpp +++ b/third_party/mlir/tools/mlir-cuda-runner/mlir-cuda-runner.cpp @@ -133,7 +133,7 @@ static LogicalResult runMLIRPasses(ModuleOp m) { pm.addPass(createConvertToLLVMIRPass([](LLVMTypeConverter &converter, OwningRewritePatternList &patterns) { populateStdToLLVMConversionPatterns(converter, patterns); - patterns.push_back(llvm::make_unique(converter)); + patterns.insert(converter); })); pm.addPass(createLowerGpuOpsToNVVMOpsPass()); pm.addPass(createConvertGPUKernelToCubinPass(&compilePtxToCubin)); diff --git a/third_party/mlir/tools/mlir-tblgen/RewriterGen.cpp b/third_party/mlir/tools/mlir-tblgen/RewriterGen.cpp index d408ecfa5eb..24eeaf50d78 100644 --- a/third_party/mlir/tools/mlir-tblgen/RewriterGen.cpp +++ b/third_party/mlir/tools/mlir-tblgen/RewriterGen.cpp @@ -935,8 +935,7 @@ static void emitRewriters(const RecordKeeper &recordKeeper, raw_ostream &os) { os << "void populateWithGenerated(MLIRContext *context, " << "OwningRewritePatternList *patterns) {\n"; for (const auto &name : rewriterNames) { - os << " patterns->push_back(llvm::make_unique<" << name - << ">(context));\n"; + os << " patterns->insert<" << name << ">(context);\n"; } os << "}\n"; }