NFC: Implement OwningRewritePatternList as a class instead of a using directive.

This allows for proper forward declaration, as opposed to leaking the internal implementation via a using directive. This also allows for all pattern building to go through 'insert' methods on the OwningRewritePatternList, replacing uses of 'push_back' and 'RewriteListBuilder'.

PiperOrigin-RevId: 261816316
This commit is contained in:
River Riddle 2019-08-05 18:37:56 -07:00 committed by TensorFlower Gardener
parent 85a9058eff
commit 4add221b50
33 changed files with 129 additions and 153 deletions

View File

@ -506,7 +506,7 @@ OpFoldResult ReshapeOp::fold(ArrayRef<Attribute> operands) {
void ReshapeOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) {
results.push_back(llvm::make_unique<RemoveAdjacentReshape>(context));
results.insert<RemoveAdjacentReshape>(context);
}
//===----------------------------------------------------------------------===//
@ -591,7 +591,7 @@ struct DropFakeQuant : public RewritePattern {
void FakeQuantOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) {
results.push_back(llvm::make_unique<DropFakeQuant>(context));
results.insert<DropFakeQuant>(context);
}
//===----------------------------------------------------------------------===//

View File

@ -205,9 +205,9 @@ void LegalizeTF::runOnFunction() {
// Add the generated patterns to the list.
populateWithGenerated(ctx, &patterns);
RewriteListBuilder<ConvertTFConcatOp, ConvertTFConcatV2Op, ConvertTFMatMulOp,
ConvertTFPackOp, ConvertTFSplitOp, ConvertTFSplitVOp,
ConvertTFUnpackOp>::build(patterns, ctx);
patterns.insert<ConvertTFConcatOp, ConvertTFConcatV2Op, ConvertTFMatMulOp,
ConvertTFPackOp, ConvertTFSplitOp, ConvertTFSplitVOp,
ConvertTFUnpackOp>(ctx);
applyPatternsGreedily(func, std::move(patterns));
}

View File

@ -240,8 +240,8 @@ void Optimize::runOnFunction() {
auto func = getFunction();
// Add the generated patterns to the list.
TFL::populateWithGenerated(ctx, &patterns);
RewriteListBuilder<FuseFullyConnectedAndAdd, FuseFullyConnectedAndRelu,
PadStridedSliceDims>::build(patterns, ctx);
patterns.insert<FuseFullyConnectedAndAdd, FuseFullyConnectedAndRelu,
PadStridedSliceDims>(ctx);
applyPatternsGreedily(func, std::move(patterns));
}

View File

@ -381,8 +381,7 @@ void PrepareTFPass::runOnFunction() {
// parameters from the TF Quant ops, thus this pattern should run with the
// first `applyPatternsGreedily` method, which would otherwise removes the
// TF FakeQuant ops by the constant folding.
patterns.push_back(
llvm::make_unique<InsertTFLQuantOpsAfterTFFakeQuantOp>(&getContext()));
patterns.insert<InsertTFLQuantOpsAfterTFFakeQuantOp>(&getContext());
TFL::populateWithGenerated(&getContext(), &patterns);
// TODO(karimnosseir): Split to separate pass probably after
// deciding on long term plan for this optimization.
@ -394,9 +393,8 @@ void PrepareTFPass::runOnFunction() {
// Load the generated pattern again, so new quantization pass-through
// will be applied.
TFL::populateWithGenerated(&getContext(), &patterns);
patterns.push_back(llvm::make_unique<ConvertTFConv2D>(&getContext()));
patterns.push_back(
llvm::make_unique<ConvertTFDepthwiseConv2dNative>(&getContext()));
patterns.insert<ConvertTFConv2D, ConvertTFDepthwiseConv2dNative>(
&getContext());
applyPatternsGreedily(func, std::move(patterns));
}

View File

@ -55,8 +55,8 @@ void QuantizePass::runOnFunction() {
auto func = getFunction();
auto* ctx = func.getContext();
TFL::populateWithGenerated(ctx, &patterns);
mlir::RewriteListBuilder<mlir::TFL::GenericFullQuantizationPattern<
mlir::TFL::QuantizeOp, mlir::TFL::DequantizeOp>>::build(patterns, ctx);
patterns.insert<mlir::TFL::GenericFullQuantizationPattern<
mlir::TFL::QuantizeOp, mlir::TFL::DequantizeOp>>(ctx);
applyPatternsGreedily(func, std::move(patterns));
}
} // namespace

View File

@ -106,7 +106,7 @@ namespace {
void AddOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) {
RewriteListBuilder<AddToAddV2>::build(results, context);
results.insert<AddToAddV2>(context);
}
//===----------------------------------------------------------------------===//
@ -115,7 +115,7 @@ void AddOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
void AddV2Op::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) {
RewriteListBuilder<AddV2OfNegLeft, AddV2OfNegRight>::build(results, context);
results.insert<AddV2OfNegLeft, AddV2OfNegRight>(context);
}
//===----------------------------------------------------------------------===//
@ -144,7 +144,7 @@ struct AssertWithTrue : public OpRewritePattern<AssertOp> {
void AssertOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) {
RewriteListBuilder<AssertWithTrue>::build(results, context);
results.insert<AssertWithTrue>(context);
}
//===----------------------------------------------------------------------===//
@ -153,7 +153,7 @@ void AssertOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
void BitcastOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) {
RewriteListBuilder<BitcastSameType, BitcastNested>::build(results, context);
results.insert<BitcastSameType, BitcastNested>(context);
}
//===----------------------------------------------------------------------===//
@ -174,7 +174,7 @@ static LogicalResult Verify(BroadcastToOp op) {
void CastOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) {
RewriteListBuilder<CastSameType>::build(results, context);
results.insert<CastSameType>(context);
}
//===----------------------------------------------------------------------===//
@ -183,7 +183,7 @@ void CastOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
void ConjOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) {
RewriteListBuilder<ConjNested>::build(results, context);
results.insert<ConjNested>(context);
}
//===----------------------------------------------------------------------===//
@ -239,7 +239,7 @@ void ConstOp::build(Builder *builder, OperationState *result, Type type,
void DivOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) {
RewriteListBuilder<DivWithSqrtDivisor>::build(results, context);
results.insert<DivWithSqrtDivisor>(context);
}
//===----------------------------------------------------------------------===//
@ -415,7 +415,7 @@ static LogicalResult Verify(IfOp op) {
void InvertOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) {
RewriteListBuilder<InvertNested>::build(results, context);
results.insert<InvertNested>(context);
}
//===----------------------------------------------------------------------===//
@ -449,7 +449,7 @@ OpFoldResult LeakyReluOp::fold(ArrayRef<Attribute> operands) {
void LogOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) {
RewriteListBuilder<LogOfSoftmax>::build(results, context);
results.insert<LogOfSoftmax>(context);
}
//===----------------------------------------------------------------------===//
@ -458,10 +458,9 @@ void LogOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
void LogicalNotOp::getCanonicalizationPatterns(
OwningRewritePatternList &results, MLIRContext *context) {
RewriteListBuilder<LogicalNotNested, LogicalNotOfEqual, LogicalNotOfNotEqual,
LogicalNotOfGreater, LogicalNotOfGreaterEqual,
LogicalNotOfLess, LogicalNotOfLessEqual>::build(results,
context);
results.insert<LogicalNotNested, LogicalNotOfEqual, LogicalNotOfNotEqual,
LogicalNotOfGreater, LogicalNotOfGreaterEqual,
LogicalNotOfLess, LogicalNotOfLessEqual>(context);
}
//===----------------------------------------------------------------------===//
@ -470,7 +469,7 @@ void LogicalNotOp::getCanonicalizationPatterns(
void NegOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) {
RewriteListBuilder<NegNested>::build(results, context);
results.insert<NegNested>(context);
}
//===----------------------------------------------------------------------===//
@ -479,7 +478,7 @@ void NegOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
void ReciprocalOp::getCanonicalizationPatterns(
OwningRewritePatternList &results, MLIRContext *context) {
RewriteListBuilder<ReciprocalNested>::build(results, context);
results.insert<ReciprocalNested>(context);
}
//===----------------------------------------------------------------------===//
@ -538,7 +537,7 @@ void RankOp::build(Builder *builder, OperationState *result, Value *input) {
void RealDivOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) {
RewriteListBuilder<RealDivWithSqrtDivisor>::build(results, context);
results.insert<RealDivWithSqrtDivisor>(context);
}
//===----------------------------------------------------------------------===//
@ -698,7 +697,7 @@ static LogicalResult Verify(SoftmaxOp op) {
void SquareOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) {
RewriteListBuilder<SquareOfSub>::build(results, context);
results.insert<SquareOfSub>(context);
}
//===----------------------------------------------------------------------===//
@ -707,7 +706,7 @@ void SquareOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
void SubOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) {
RewriteListBuilder<SubOfNeg>::build(results, context);
results.insert<SubOfNeg>(context);
}
//===----------------------------------------------------------------------===//
@ -778,7 +777,7 @@ void TransposeOp::build(Builder *builder, OperationState *result, Value *x,
void TruncateDivOp::getCanonicalizationPatterns(
OwningRewritePatternList &results, MLIRContext *context) {
RewriteListBuilder<TruncateDivWithSqrtDivisor>::build(results, context);
results.insert<TruncateDivWithSqrtDivisor>(context);
}
//===----------------------------------------------------------------------===//
@ -867,7 +866,7 @@ static LogicalResult Verify(WhileOp op) {
void XdivyOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) {
RewriteListBuilder<XdivyWithSqrtDivisor>::build(results, context);
results.insert<XdivyWithSqrtDivisor>(context);
}
//===----------------------------------------------------------------------===//

View File

@ -133,10 +133,8 @@ void LegalizeToStandard::runOnFunction() {
auto func = getFunction();
mlir::XLA::populateWithGenerated(func.getContext(), &patterns);
patterns.push_back(
llvm::make_unique<mlir::XLA::CompareFConvert>(&getContext()));
patterns.push_back(
llvm::make_unique<mlir::XLA::CompareIConvert>(&getContext()));
patterns.insert<mlir::XLA::CompareFConvert, mlir::XLA::CompareIConvert>(
&getContext());
applyPatternsGreedily(func, std::move(patterns));
}

View File

@ -29,7 +29,7 @@ class MLIRContext;
class RewritePattern;
// Owning list of rewriting patterns.
using OwningRewritePatternList = std::vector<std::unique_ptr<RewritePattern>>;
class OwningRewritePatternList;
/// Collect a set of patterns to lower from loop.for, loop.if, and
/// loop.terminator to CFG operations within the Standard dialect, in particular

View File

@ -38,7 +38,7 @@ class RewritePattern;
class Type;
// Owning list of rewriting patterns.
using OwningRewritePatternList = std::vector<std::unique_ptr<RewritePattern>>;
class OwningRewritePatternList;
/// Type for a callback constructing the owning list of patterns for the
/// conversion to the LLVMIR dialect. The callback is expected to append

View File

@ -57,9 +57,7 @@ class Value;
/// either OpTy or OperandAdaptor<OpTy> seamlessly.
template <typename OpTy> using OperandAdaptor = typename OpTy::OperandAdaptor;
/// This is a vector that owns the patterns inside of it.
using OwningPatternList = std::vector<std::unique_ptr<Pattern>>;
using OwningRewritePatternList = std::vector<std::unique_ptr<RewritePattern>>;
class OwningRewritePatternList;
enum class OperationProperty {
/// This bit is set for an operation if it is a commutative operation: that

View File

@ -394,8 +394,39 @@ private:
// Pattern-driven rewriters
//===----------------------------------------------------------------------===//
/// This is a vector that owns the patterns inside of it.
using OwningRewritePatternList = std::vector<std::unique_ptr<RewritePattern>>;
class OwningRewritePatternList {
using PatternListT = std::vector<std::unique_ptr<RewritePattern>>;
public:
PatternListT::iterator begin() { return patterns.begin(); }
PatternListT::iterator end() { return patterns.end(); }
PatternListT::const_iterator begin() const { return patterns.begin(); }
PatternListT::const_iterator end() const { return patterns.end(); }
//===--------------------------------------------------------------------===//
// Pattern Insertion
//===--------------------------------------------------------------------===//
void insert(RewritePattern *pattern) { patterns.emplace_back(pattern); }
/// Add an instance of each of the pattern types 'Ts' to the pattern list with
/// the given arguments.
// Note: ConstructorArg is necessary here to separate the two variadic lists.
template <typename... Ts, typename ConstructorArg,
typename... ConstructorArgs>
void insert(ConstructorArg &&arg, ConstructorArgs &&... args) {
// The following expands a call to emplace_back for each of the pattern
// types 'Ts'. This magic is necessary due to a limitation in the places
// that a parameter pack can be expanded in c++11.
// FIXME: In c++17 this can be simplified by using 'fold expressions'.
using dummy = int[];
(void)dummy{
0, (patterns.emplace_back(llvm::make_unique<Ts>(arg, args...)), 0)...};
}
private:
PatternListT patterns;
};
/// This class manages optimization and execution of a group of rewrite
/// patterns, providing an API for finding and applying, the best match against
@ -404,7 +435,7 @@ using OwningRewritePatternList = std::vector<std::unique_ptr<RewritePattern>>;
class RewritePatternMatcher {
public:
/// Create a RewritePatternMatcher with the specified set of patterns.
explicit RewritePatternMatcher(OwningRewritePatternList &&patterns);
explicit RewritePatternMatcher(OwningRewritePatternList &patterns);
/// Try to match the given operation to a pattern and rewrite it. Return
/// true if any pattern matches.
@ -416,7 +447,7 @@ private:
/// The group of patterns that are matched for optimization through this
/// matcher.
OwningRewritePatternList patterns;
std::vector<RewritePattern *> patterns;
};
/// Rewrite the regions of the specified operation, which must be isolated from
@ -427,29 +458,6 @@ private:
///
bool applyPatternsGreedily(Operation *op, OwningRewritePatternList &&patterns);
/// Helper class to create a list of rewrite patterns given a list of their
/// types and a list of attributes perfect-forwarded to each of the conversion
/// constructors.
template <typename Arg, typename... Args> struct RewriteListBuilder {
template <typename... ConstructorArgs>
static void build(OwningRewritePatternList &patterns,
ConstructorArgs &&... constructorArgs) {
RewriteListBuilder<Args...>::build(
patterns, std::forward<ConstructorArgs>(constructorArgs)...);
RewriteListBuilder<Arg>::build(
patterns, std::forward<ConstructorArgs>(constructorArgs)...);
}
};
// Template specialization to stop recursion.
template <typename Arg> struct RewriteListBuilder<Arg> {
template <typename... ConstructorArgs>
static void build(OwningRewritePatternList &patterns,
ConstructorArgs &&... constructorArgs) {
patterns.emplace_back(llvm::make_unique<Arg>(
std::forward<ConstructorArgs>(constructorArgs)...));
}
};
} // end namespace mlir
#endif // MLIR_PATTERN_MATCH_H

View File

@ -32,7 +32,7 @@ class RewritePattern;
class Value;
// Owning list of rewriting patterns.
using OwningRewritePatternList = std::vector<std::unique_ptr<RewritePattern>>;
class OwningRewritePatternList;
/// Emit code that computes the given affine expression using standard
/// arithmetic operations applied to the provided dimension and symbol values.

View File

@ -708,7 +708,7 @@ struct SimplifyAffineApply : public OpRewritePattern<AffineApplyOp> {
void AffineApplyOp::getCanonicalizationPatterns(
OwningRewritePatternList &results, MLIRContext *context) {
results.push_back(llvm::make_unique<SimplifyAffineApply>(context));
results.insert<SimplifyAffineApply>(context);
}
//===----------------------------------------------------------------------===//
@ -912,8 +912,7 @@ LogicalResult AffineDmaStartOp::verify() {
void AffineDmaStartOp::getCanonicalizationPatterns(
OwningRewritePatternList &results, MLIRContext *context) {
/// dma_start(memrefcast) -> dma_start
results.push_back(
llvm::make_unique<MemRefCastFolder>(getOperationName(), context));
results.insert<MemRefCastFolder>(getOperationName(), context);
}
//===----------------------------------------------------------------------===//
@ -989,8 +988,7 @@ LogicalResult AffineDmaWaitOp::verify() {
void AffineDmaWaitOp::getCanonicalizationPatterns(
OwningRewritePatternList &results, MLIRContext *context) {
/// dma_wait(memrefcast) -> dma_wait
results.push_back(
llvm::make_unique<MemRefCastFolder>(getOperationName(), context));
results.insert<MemRefCastFolder>(getOperationName(), context);
}
//===----------------------------------------------------------------------===//
@ -1333,7 +1331,7 @@ struct AffineForLoopBoundFolder : public OpRewritePattern<AffineForOp> {
void AffineForOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) {
results.push_back(llvm::make_unique<AffineForLoopBoundFolder>(context));
results.insert<AffineForLoopBoundFolder>(context);
}
AffineBound AffineForOp::getLowerBound() {
@ -1659,8 +1657,7 @@ LogicalResult AffineLoadOp::verify() {
void AffineLoadOp::getCanonicalizationPatterns(
OwningRewritePatternList &results, MLIRContext *context) {
/// load(memrefcast) -> load
results.push_back(
llvm::make_unique<MemRefCastFolder>(getOperationName(), context));
results.insert<MemRefCastFolder>(getOperationName(), context);
}
//===----------------------------------------------------------------------===//
@ -1752,8 +1749,7 @@ LogicalResult AffineStoreOp::verify() {
void AffineStoreOp::getCanonicalizationPatterns(
OwningRewritePatternList &results, MLIRContext *context) {
/// load(memrefcast) -> load
results.push_back(
llvm::make_unique<MemRefCastFolder>(getOperationName(), context));
results.insert<MemRefCastFolder>(getOperationName(), context);
}
#define GET_OP_CLASSES

View File

@ -258,8 +258,7 @@ IfLowering::matchAndRewrite(IfOp ifOp, PatternRewriter &rewriter) const {
void mlir::populateLoopToStdConversionPatterns(
OwningRewritePatternList &patterns, MLIRContext *ctx) {
RewriteListBuilder<ForLowering, IfLowering, TerminatorLowering>::build(
patterns, ctx);
patterns.insert<ForLowering, IfLowering, TerminatorLowering>(ctx);
}
void ControlFlowToCFGPass::runOnFunction() {

View File

@ -104,8 +104,7 @@ void GPUToSPIRVPass::runOnModule() {
SPIRVTypeConverter typeConverter(context);
SPIRVEntryFnTypeConverter entryFnConverter(context);
OwningRewritePatternList patterns;
RewriteListBuilder<KernelFnConversion>::build(
patterns, context, typeConverter, entryFnConverter);
patterns.insert<KernelFnConversion>(context, typeConverter, entryFnConverter);
populateStandardToSPIRVPatterns(context, patterns);
ConversionTarget target(*context);

View File

@ -1023,7 +1023,7 @@ void mlir::LLVM::ensureDistinctSuccessors(ModuleOp m) {
void mlir::populateStdToLLVMConversionPatterns(
LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
// FIXME: this should be tablegen'ed
RewriteListBuilder<
patterns.insert<
AddFOpLowering, AddIOpLowering, AndOpLowering, AllocOpLowering,
BranchOpLowering, CallIndirectOpLowering, CallOpLowering, CmpIOpLowering,
CondBranchOpLowering, ConstLLVMOpLowering, DeallocOpLowering,
@ -1032,8 +1032,7 @@ void mlir::populateStdToLLVMConversionPatterns(
MemRefCastOpLowering, MulFOpLowering, MulIOpLowering, OrOpLowering,
RemISOpLowering, RemIUOpLowering, RemFOpLowering, ReturnOpLowering,
SelectOpLowering, SIToFPLowering, StoreOpLowering, SubFOpLowering,
SubIOpLowering, XOrOpLowering>::build(patterns, *converter.getDialect(),
converter);
SubIOpLowering, XOrOpLowering>(*converter.getDialect(), converter);
}
// Convert types using the stored LLVM IR module.

View File

@ -201,6 +201,6 @@ void populateStandardToSPIRVPatterns(MLIRContext *context,
OwningRewritePatternList &patterns) {
populateWithGenerated(context, &patterns);
// Add the return op conversion.
RewriteListBuilder<ReturnToSPIRVConversion>::build(patterns, context);
patterns.insert<ReturnToSPIRVConversion>(context);
}
} // namespace mlir

View File

@ -368,8 +368,7 @@ void LowerUniformRealMathPass::runOnFunction() {
auto fn = getFunction();
OwningRewritePatternList patterns;
auto *context = &getContext();
patterns.push_back(llvm::make_unique<UniformRealAddEwPattern>(context));
patterns.push_back(llvm::make_unique<UniformRealMulEwPattern>(context));
patterns.insert<UniformRealAddEwPattern, UniformRealMulEwPattern>(context);
applyPatternsGreedily(fn, std::move(patterns));
}
@ -389,7 +388,7 @@ void LowerUniformCastsPass::runOnFunction() {
auto fn = getFunction();
OwningRewritePatternList patterns;
auto *context = &getContext();
patterns.push_back(llvm::make_unique<UniformDequantizePattern>(context));
patterns.insert<UniformDequantizePattern>(context);
applyPatternsGreedily(fn, std::move(patterns));
}

View File

@ -372,7 +372,7 @@ class PropagateConstantBounds : public OpRewritePattern<LaunchOp> {
void LaunchOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) {
RewriteListBuilder<PropagateConstantBounds>::build(results, context);
results.insert<PropagateConstantBounds>(context);
}
//===----------------------------------------------------------------------===//

View File

@ -60,8 +60,7 @@ public:
void StorageCastOp::getCanonicalizationPatterns(
OwningRewritePatternList &patterns, MLIRContext *context) {
patterns.push_back(
llvm::make_unique<RemoveRedundantStorageCastsRewrite>(context));
patterns.insert<RemoveRedundantStorageCastsRewrite>(context);
}
QuantizationDialect::QuantizationDialect(MLIRContext *context)

View File

@ -108,7 +108,7 @@ void ConvertConstPass::runOnFunction() {
OwningRewritePatternList patterns;
auto func = getFunction();
auto *context = &getContext();
patterns.push_back(llvm::make_unique<QuantizedConstRewrite>(context));
patterns.insert<QuantizedConstRewrite>(context);
applyPatternsGreedily(func, std::move(patterns));
}

View File

@ -97,8 +97,7 @@ void ConvertSimulatedQuantPass::runOnFunction() {
OwningRewritePatternList patterns;
auto func = getFunction();
auto *context = &getContext();
patterns.push_back(
llvm::make_unique<ConstFakeQuantRewrite>(context, &hadFailure));
patterns.insert<ConstFakeQuantRewrite>(context, &hadFailure);
applyPatternsGreedily(func, std::move(patterns));
if (hadFailure)
signalPassFailure();

View File

@ -149,12 +149,13 @@ void PatternRewriter::updatedRootInPlace(
//===----------------------------------------------------------------------===//
RewritePatternMatcher::RewritePatternMatcher(
OwningRewritePatternList &&patterns)
: patterns(std::move(patterns)) {
OwningRewritePatternList &patterns) {
for (auto &pattern : patterns)
this->patterns.push_back(pattern.get());
// Sort the patterns by benefit to simplify the matching logic.
std::stable_sort(this->patterns.begin(), this->patterns.end(),
[](const std::unique_ptr<RewritePattern> &l,
const std::unique_ptr<RewritePattern> &r) {
[](RewritePattern *l, RewritePattern *r) {
return r->getBenefit() < l->getBenefit();
});
}
@ -162,7 +163,7 @@ RewritePatternMatcher::RewritePatternMatcher(
/// Try to match the given operation to a pattern and rewrite it.
bool RewritePatternMatcher::matchAndRewrite(Operation *op,
PatternRewriter &rewriter) {
for (auto &pattern : patterns) {
for (auto *pattern : patterns) {
// Ignore patterns that are for the wrong root or are impossible to match.
if (pattern->getRootKind() != op->getName() ||
pattern->getBenefit().isImpossibleToMatch())

View File

@ -678,12 +678,11 @@ static void
populateLinalgToLLVMConversionPatterns(LinalgTypeConverter &converter,
OwningRewritePatternList &patterns,
MLIRContext *ctx) {
RewriteListBuilder<BufferAllocOpConversion, BufferDeallocOpConversion,
BufferSizeOpConversion, DimOpConversion,
LinalgOpConversion<DotOp>, LinalgOpConversion<MatmulOp>,
LoadOpConversion, RangeOpConversion, SliceOpConversion,
StoreOpConversion, ViewOpConversion>::build(patterns, ctx,
converter);
patterns.insert<BufferAllocOpConversion, BufferDeallocOpConversion,
BufferSizeOpConversion, DimOpConversion,
LinalgOpConversion<DotOp>, LinalgOpConversion<MatmulOp>,
LoadOpConversion, RangeOpConversion, SliceOpConversion,
StoreOpConversion, ViewOpConversion>(ctx, converter);
}
namespace {

View File

@ -60,12 +60,9 @@ void RemoveInstrumentationPass::runOnFunction() {
OwningRewritePatternList patterns;
auto func = getFunction();
auto *context = &getContext();
patterns.push_back(
llvm::make_unique<RemoveIdentityOpRewrite<StatisticsOp>>(context));
patterns.push_back(
llvm::make_unique<RemoveIdentityOpRewrite<StatisticsRefOp>>(context));
patterns.push_back(
llvm::make_unique<RemoveIdentityOpRewrite<CoupledRefOp>>(context));
patterns.insert<RemoveIdentityOpRewrite<StatisticsOp>,
RemoveIdentityOpRewrite<StatisticsRefOp>,
RemoveIdentityOpRewrite<CoupledRefOp>>(context);
applyPatternsGreedily(func, std::move(patterns));
}

View File

@ -365,8 +365,7 @@ struct SimplifyDeadAlloc : public OpRewritePattern<AllocOp> {
void AllocOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) {
RewriteListBuilder<SimplifyAllocConst, SimplifyDeadAlloc>::build(results,
context);
results.insert<SimplifyAllocConst, SimplifyDeadAlloc>(context);
}
//===----------------------------------------------------------------------===//
@ -544,8 +543,7 @@ static LogicalResult verify(CallIndirectOp op) {
void CallIndirectOp::getCanonicalizationPatterns(
OwningRewritePatternList &results, MLIRContext *context) {
results.push_back(
llvm::make_unique<SimplifyIndirectCallWithKnownCallee>(context));
results.insert<SimplifyIndirectCallWithKnownCallee>(context);
}
//===----------------------------------------------------------------------===//
@ -1015,7 +1013,7 @@ static void print(OpAsmPrinter *p, CondBranchOp op) {
void CondBranchOp::getCanonicalizationPatterns(
OwningRewritePatternList &results, MLIRContext *context) {
results.push_back(llvm::make_unique<SimplifyConstCondBranchPred>(context));
results.insert<SimplifyConstCondBranchPred>(context);
}
//===----------------------------------------------------------------------===//
@ -1231,9 +1229,8 @@ static LogicalResult verify(DeallocOp op) {
void DeallocOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) {
/// dealloc(memrefcast) -> dealloc
results.push_back(
llvm::make_unique<MemRefCastFolder>(getOperationName(), context));
results.push_back(llvm::make_unique<SimplifyDeadDealloc>(context));
results.insert<MemRefCastFolder>(getOperationName(), context);
results.insert<SimplifyDeadDealloc>(context);
}
//===----------------------------------------------------------------------===//
@ -1497,8 +1494,7 @@ LogicalResult DmaStartOp::verify() {
void DmaStartOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) {
/// dma_start(memrefcast) -> dma_start
results.push_back(
llvm::make_unique<MemRefCastFolder>(getOperationName(), context));
results.insert<MemRefCastFolder>(getOperationName(), context);
}
// ---------------------------------------------------------------------------
@ -1561,8 +1557,7 @@ ParseResult DmaWaitOp::parse(OpAsmParser *parser, OperationState *result) {
void DmaWaitOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) {
/// dma_wait(memrefcast) -> dma_wait
results.push_back(
llvm::make_unique<MemRefCastFolder>(getOperationName(), context));
results.insert<MemRefCastFolder>(getOperationName(), context);
}
//===----------------------------------------------------------------------===//
@ -1695,8 +1690,7 @@ static LogicalResult verify(LoadOp op) {
void LoadOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) {
/// load(memrefcast) -> load
results.push_back(
llvm::make_unique<MemRefCastFolder>(getOperationName(), context));
results.insert<MemRefCastFolder>(getOperationName(), context);
}
//===----------------------------------------------------------------------===//
@ -2007,8 +2001,7 @@ static LogicalResult verify(StoreOp op) {
void StoreOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) {
/// store(memrefcast) -> store
results.push_back(
llvm::make_unique<MemRefCastFolder>(getOperationName(), context));
results.insert<MemRefCastFolder>(getOperationName(), context);
}
//===----------------------------------------------------------------------===//

View File

@ -1243,8 +1243,7 @@ struct FuncOpSignatureConversion : public ConversionPattern {
void mlir::populateFuncOpTypeConversionPattern(
OwningRewritePatternList &patterns, MLIRContext *ctx,
TypeConverter &converter) {
RewriteListBuilder<FuncOpSignatureConversion>::build(patterns, ctx,
converter);
patterns.insert<FuncOpSignatureConversion>(ctx, converter);
}
/// This function converts the type signature of the given block, by invoking

View File

@ -507,10 +507,11 @@ public:
void mlir::populateAffineToStdConversionPatterns(
OwningRewritePatternList &patterns, MLIRContext *ctx) {
RewriteListBuilder<AffineApplyLowering, AffineDmaStartLowering,
AffineDmaWaitLowering, AffineLoadLowering,
AffineStoreLowering, AffineForLowering, AffineIfLowering,
AffineTerminatorLowering>::build(patterns, ctx);
patterns
.insert<AffineApplyLowering, AffineDmaStartLowering,
AffineDmaWaitLowering, AffineLoadLowering, AffineStoreLowering,
AffineForLowering, AffineIfLowering, AffineTerminatorLowering>(
ctx);
}
namespace {

View File

@ -365,12 +365,8 @@ struct LowerVectorTransfersPass
void runOnFunction() {
OwningRewritePatternList patterns;
auto *context = &getContext();
patterns.push_back(
llvm::make_unique<VectorTransferRewriter<VectorTransferReadOp>>(
context));
patterns.push_back(
llvm::make_unique<VectorTransferRewriter<VectorTransferWriteOp>>(
context));
patterns.insert<VectorTransferRewriter<VectorTransferReadOp>,
VectorTransferRewriter<VectorTransferWriteOp>>(context);
applyPatternsGreedily(getFunction(), std::move(patterns));
}
};

View File

@ -44,8 +44,8 @@ namespace {
class GreedyPatternRewriteDriver : public PatternRewriter {
public:
explicit GreedyPatternRewriteDriver(MLIRContext *ctx,
OwningRewritePatternList &&patterns)
: PatternRewriter(ctx), matcher(std::move(patterns)) {
OwningRewritePatternList &patterns)
: PatternRewriter(ctx), matcher(patterns) {
worklist.reserve(64);
}
@ -224,7 +224,7 @@ bool mlir::applyPatternsGreedily(Operation *op,
if (!op->isKnownIsolatedFromAbove())
return false;
GreedyPatternRewriteDriver driver(op->getContext(), std::move(patterns));
GreedyPatternRewriteDriver driver(op->getContext(), patterns);
bool converged = driver.simplify(op, maxPatternMatchIterations);
LLVM_DEBUG(if (!converged) {
llvm::dbgs() << "The pattern rewrite doesn't converge after scanning "

View File

@ -41,7 +41,7 @@ struct TestPatternDriver : public FunctionPass<TestPatternDriver> {
populateWithGenerated(&getContext(), &patterns);
// Verify named pattern is generated with expected name.
RewriteListBuilder<TestNamedPatternRule>::build(patterns, &getContext());
patterns.insert<TestNamedPatternRule>(&getContext());
applyPatternsGreedily(getFunction(), std::move(patterns));
}
@ -193,9 +193,9 @@ struct TestLegalizePatternDriver
TestTypeConverter converter;
mlir::OwningRewritePatternList patterns;
populateWithGenerated(&getContext(), &patterns);
RewriteListBuilder<TestRegionRewriteBlockMovement, TestRegionRewriteUndo,
TestDropOp, TestPassthroughInvalidOp,
TestSplitReturnType>::build(patterns, &getContext());
patterns.insert<TestRegionRewriteBlockMovement, TestRegionRewriteUndo,
TestDropOp, TestPassthroughInvalidOp, TestSplitReturnType>(
&getContext());
mlir::populateFuncOpTypeConversionPattern(patterns, &getContext(),
converter);

View File

@ -133,7 +133,7 @@ static LogicalResult runMLIRPasses(ModuleOp m) {
pm.addPass(createConvertToLLVMIRPass([](LLVMTypeConverter &converter,
OwningRewritePatternList &patterns) {
populateStdToLLVMConversionPatterns(converter, patterns);
patterns.push_back(llvm::make_unique<GPULaunchFuncOpLowering>(converter));
patterns.insert<GPULaunchFuncOpLowering>(converter);
}));
pm.addPass(createLowerGpuOpsToNVVMOpsPass());
pm.addPass(createConvertGPUKernelToCubinPass(&compilePtxToCubin));

View File

@ -935,8 +935,7 @@ static void emitRewriters(const RecordKeeper &recordKeeper, raw_ostream &os) {
os << "void populateWithGenerated(MLIRContext *context, "
<< "OwningRewritePatternList *patterns) {\n";
for (const auto &name : rewriterNames) {
os << " patterns->push_back(llvm::make_unique<" << name
<< ">(context));\n";
os << " patterns->insert<" << name << ">(context);\n";
}
os << "}\n";
}