NFC: Implement OwningRewritePatternList as a class instead of a using directive.
This allows for proper forward declaration, as opposed to leaking the internal implementation via a using directive. This also allows for all pattern building to go through 'insert' methods on the OwningRewritePatternList, replacing uses of 'push_back' and 'RewriteListBuilder'. PiperOrigin-RevId: 261816316
This commit is contained in:
parent
85a9058eff
commit
4add221b50
@ -506,7 +506,7 @@ OpFoldResult ReshapeOp::fold(ArrayRef<Attribute> operands) {
|
||||
|
||||
void ReshapeOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
||||
MLIRContext *context) {
|
||||
results.push_back(llvm::make_unique<RemoveAdjacentReshape>(context));
|
||||
results.insert<RemoveAdjacentReshape>(context);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
@ -591,7 +591,7 @@ struct DropFakeQuant : public RewritePattern {
|
||||
|
||||
void FakeQuantOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
||||
MLIRContext *context) {
|
||||
results.push_back(llvm::make_unique<DropFakeQuant>(context));
|
||||
results.insert<DropFakeQuant>(context);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -205,9 +205,9 @@ void LegalizeTF::runOnFunction() {
|
||||
|
||||
// Add the generated patterns to the list.
|
||||
populateWithGenerated(ctx, &patterns);
|
||||
RewriteListBuilder<ConvertTFConcatOp, ConvertTFConcatV2Op, ConvertTFMatMulOp,
|
||||
patterns.insert<ConvertTFConcatOp, ConvertTFConcatV2Op, ConvertTFMatMulOp,
|
||||
ConvertTFPackOp, ConvertTFSplitOp, ConvertTFSplitVOp,
|
||||
ConvertTFUnpackOp>::build(patterns, ctx);
|
||||
ConvertTFUnpackOp>(ctx);
|
||||
applyPatternsGreedily(func, std::move(patterns));
|
||||
}
|
||||
|
||||
|
@ -240,8 +240,8 @@ void Optimize::runOnFunction() {
|
||||
auto func = getFunction();
|
||||
// Add the generated patterns to the list.
|
||||
TFL::populateWithGenerated(ctx, &patterns);
|
||||
RewriteListBuilder<FuseFullyConnectedAndAdd, FuseFullyConnectedAndRelu,
|
||||
PadStridedSliceDims>::build(patterns, ctx);
|
||||
patterns.insert<FuseFullyConnectedAndAdd, FuseFullyConnectedAndRelu,
|
||||
PadStridedSliceDims>(ctx);
|
||||
|
||||
applyPatternsGreedily(func, std::move(patterns));
|
||||
}
|
||||
|
@ -381,8 +381,7 @@ void PrepareTFPass::runOnFunction() {
|
||||
// parameters from the TF Quant ops, thus this pattern should run with the
|
||||
// first `applyPatternsGreedily` method, which would otherwise removes the
|
||||
// TF FakeQuant ops by the constant folding.
|
||||
patterns.push_back(
|
||||
llvm::make_unique<InsertTFLQuantOpsAfterTFFakeQuantOp>(&getContext()));
|
||||
patterns.insert<InsertTFLQuantOpsAfterTFFakeQuantOp>(&getContext());
|
||||
TFL::populateWithGenerated(&getContext(), &patterns);
|
||||
// TODO(karimnosseir): Split to separate pass probably after
|
||||
// deciding on long term plan for this optimization.
|
||||
@ -394,9 +393,8 @@ void PrepareTFPass::runOnFunction() {
|
||||
// Load the generated pattern again, so new quantization pass-through
|
||||
// will be applied.
|
||||
TFL::populateWithGenerated(&getContext(), &patterns);
|
||||
patterns.push_back(llvm::make_unique<ConvertTFConv2D>(&getContext()));
|
||||
patterns.push_back(
|
||||
llvm::make_unique<ConvertTFDepthwiseConv2dNative>(&getContext()));
|
||||
patterns.insert<ConvertTFConv2D, ConvertTFDepthwiseConv2dNative>(
|
||||
&getContext());
|
||||
applyPatternsGreedily(func, std::move(patterns));
|
||||
}
|
||||
|
||||
|
@ -55,8 +55,8 @@ void QuantizePass::runOnFunction() {
|
||||
auto func = getFunction();
|
||||
auto* ctx = func.getContext();
|
||||
TFL::populateWithGenerated(ctx, &patterns);
|
||||
mlir::RewriteListBuilder<mlir::TFL::GenericFullQuantizationPattern<
|
||||
mlir::TFL::QuantizeOp, mlir::TFL::DequantizeOp>>::build(patterns, ctx);
|
||||
patterns.insert<mlir::TFL::GenericFullQuantizationPattern<
|
||||
mlir::TFL::QuantizeOp, mlir::TFL::DequantizeOp>>(ctx);
|
||||
applyPatternsGreedily(func, std::move(patterns));
|
||||
}
|
||||
} // namespace
|
||||
|
@ -106,7 +106,7 @@ namespace {
|
||||
|
||||
void AddOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
||||
MLIRContext *context) {
|
||||
RewriteListBuilder<AddToAddV2>::build(results, context);
|
||||
results.insert<AddToAddV2>(context);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
@ -115,7 +115,7 @@ void AddOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
||||
|
||||
void AddV2Op::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
||||
MLIRContext *context) {
|
||||
RewriteListBuilder<AddV2OfNegLeft, AddV2OfNegRight>::build(results, context);
|
||||
results.insert<AddV2OfNegLeft, AddV2OfNegRight>(context);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
@ -144,7 +144,7 @@ struct AssertWithTrue : public OpRewritePattern<AssertOp> {
|
||||
|
||||
void AssertOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
||||
MLIRContext *context) {
|
||||
RewriteListBuilder<AssertWithTrue>::build(results, context);
|
||||
results.insert<AssertWithTrue>(context);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
@ -153,7 +153,7 @@ void AssertOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
||||
|
||||
void BitcastOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
||||
MLIRContext *context) {
|
||||
RewriteListBuilder<BitcastSameType, BitcastNested>::build(results, context);
|
||||
results.insert<BitcastSameType, BitcastNested>(context);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
@ -174,7 +174,7 @@ static LogicalResult Verify(BroadcastToOp op) {
|
||||
|
||||
void CastOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
||||
MLIRContext *context) {
|
||||
RewriteListBuilder<CastSameType>::build(results, context);
|
||||
results.insert<CastSameType>(context);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
@ -183,7 +183,7 @@ void CastOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
||||
|
||||
void ConjOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
||||
MLIRContext *context) {
|
||||
RewriteListBuilder<ConjNested>::build(results, context);
|
||||
results.insert<ConjNested>(context);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
@ -239,7 +239,7 @@ void ConstOp::build(Builder *builder, OperationState *result, Type type,
|
||||
|
||||
void DivOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
||||
MLIRContext *context) {
|
||||
RewriteListBuilder<DivWithSqrtDivisor>::build(results, context);
|
||||
results.insert<DivWithSqrtDivisor>(context);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
@ -415,7 +415,7 @@ static LogicalResult Verify(IfOp op) {
|
||||
|
||||
void InvertOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
||||
MLIRContext *context) {
|
||||
RewriteListBuilder<InvertNested>::build(results, context);
|
||||
results.insert<InvertNested>(context);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
@ -449,7 +449,7 @@ OpFoldResult LeakyReluOp::fold(ArrayRef<Attribute> operands) {
|
||||
|
||||
void LogOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
||||
MLIRContext *context) {
|
||||
RewriteListBuilder<LogOfSoftmax>::build(results, context);
|
||||
results.insert<LogOfSoftmax>(context);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
@ -458,10 +458,9 @@ void LogOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
||||
|
||||
void LogicalNotOp::getCanonicalizationPatterns(
|
||||
OwningRewritePatternList &results, MLIRContext *context) {
|
||||
RewriteListBuilder<LogicalNotNested, LogicalNotOfEqual, LogicalNotOfNotEqual,
|
||||
results.insert<LogicalNotNested, LogicalNotOfEqual, LogicalNotOfNotEqual,
|
||||
LogicalNotOfGreater, LogicalNotOfGreaterEqual,
|
||||
LogicalNotOfLess, LogicalNotOfLessEqual>::build(results,
|
||||
context);
|
||||
LogicalNotOfLess, LogicalNotOfLessEqual>(context);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
@ -470,7 +469,7 @@ void LogicalNotOp::getCanonicalizationPatterns(
|
||||
|
||||
void NegOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
||||
MLIRContext *context) {
|
||||
RewriteListBuilder<NegNested>::build(results, context);
|
||||
results.insert<NegNested>(context);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
@ -479,7 +478,7 @@ void NegOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
||||
|
||||
void ReciprocalOp::getCanonicalizationPatterns(
|
||||
OwningRewritePatternList &results, MLIRContext *context) {
|
||||
RewriteListBuilder<ReciprocalNested>::build(results, context);
|
||||
results.insert<ReciprocalNested>(context);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
@ -538,7 +537,7 @@ void RankOp::build(Builder *builder, OperationState *result, Value *input) {
|
||||
|
||||
void RealDivOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
||||
MLIRContext *context) {
|
||||
RewriteListBuilder<RealDivWithSqrtDivisor>::build(results, context);
|
||||
results.insert<RealDivWithSqrtDivisor>(context);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
@ -698,7 +697,7 @@ static LogicalResult Verify(SoftmaxOp op) {
|
||||
|
||||
void SquareOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
||||
MLIRContext *context) {
|
||||
RewriteListBuilder<SquareOfSub>::build(results, context);
|
||||
results.insert<SquareOfSub>(context);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
@ -707,7 +706,7 @@ void SquareOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
||||
|
||||
void SubOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
||||
MLIRContext *context) {
|
||||
RewriteListBuilder<SubOfNeg>::build(results, context);
|
||||
results.insert<SubOfNeg>(context);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
@ -778,7 +777,7 @@ void TransposeOp::build(Builder *builder, OperationState *result, Value *x,
|
||||
|
||||
void TruncateDivOp::getCanonicalizationPatterns(
|
||||
OwningRewritePatternList &results, MLIRContext *context) {
|
||||
RewriteListBuilder<TruncateDivWithSqrtDivisor>::build(results, context);
|
||||
results.insert<TruncateDivWithSqrtDivisor>(context);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
@ -867,7 +866,7 @@ static LogicalResult Verify(WhileOp op) {
|
||||
|
||||
void XdivyOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
||||
MLIRContext *context) {
|
||||
RewriteListBuilder<XdivyWithSqrtDivisor>::build(results, context);
|
||||
results.insert<XdivyWithSqrtDivisor>(context);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -133,10 +133,8 @@ void LegalizeToStandard::runOnFunction() {
|
||||
auto func = getFunction();
|
||||
|
||||
mlir::XLA::populateWithGenerated(func.getContext(), &patterns);
|
||||
patterns.push_back(
|
||||
llvm::make_unique<mlir::XLA::CompareFConvert>(&getContext()));
|
||||
patterns.push_back(
|
||||
llvm::make_unique<mlir::XLA::CompareIConvert>(&getContext()));
|
||||
patterns.insert<mlir::XLA::CompareFConvert, mlir::XLA::CompareIConvert>(
|
||||
&getContext());
|
||||
applyPatternsGreedily(func, std::move(patterns));
|
||||
}
|
||||
|
||||
|
@ -29,7 +29,7 @@ class MLIRContext;
|
||||
class RewritePattern;
|
||||
|
||||
// Owning list of rewriting patterns.
|
||||
using OwningRewritePatternList = std::vector<std::unique_ptr<RewritePattern>>;
|
||||
class OwningRewritePatternList;
|
||||
|
||||
/// Collect a set of patterns to lower from loop.for, loop.if, and
|
||||
/// loop.terminator to CFG operations within the Standard dialect, in particular
|
||||
|
@ -38,7 +38,7 @@ class RewritePattern;
|
||||
class Type;
|
||||
|
||||
// Owning list of rewriting patterns.
|
||||
using OwningRewritePatternList = std::vector<std::unique_ptr<RewritePattern>>;
|
||||
class OwningRewritePatternList;
|
||||
|
||||
/// Type for a callback constructing the owning list of patterns for the
|
||||
/// conversion to the LLVMIR dialect. The callback is expected to append
|
||||
|
@ -57,9 +57,7 @@ class Value;
|
||||
/// either OpTy or OperandAdaptor<OpTy> seamlessly.
|
||||
template <typename OpTy> using OperandAdaptor = typename OpTy::OperandAdaptor;
|
||||
|
||||
/// This is a vector that owns the patterns inside of it.
|
||||
using OwningPatternList = std::vector<std::unique_ptr<Pattern>>;
|
||||
using OwningRewritePatternList = std::vector<std::unique_ptr<RewritePattern>>;
|
||||
class OwningRewritePatternList;
|
||||
|
||||
enum class OperationProperty {
|
||||
/// This bit is set for an operation if it is a commutative operation: that
|
||||
|
62
third_party/mlir/include/mlir/IR/PatternMatch.h
vendored
62
third_party/mlir/include/mlir/IR/PatternMatch.h
vendored
@ -394,8 +394,39 @@ private:
|
||||
// Pattern-driven rewriters
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// This is a vector that owns the patterns inside of it.
|
||||
using OwningRewritePatternList = std::vector<std::unique_ptr<RewritePattern>>;
|
||||
class OwningRewritePatternList {
|
||||
using PatternListT = std::vector<std::unique_ptr<RewritePattern>>;
|
||||
|
||||
public:
|
||||
PatternListT::iterator begin() { return patterns.begin(); }
|
||||
PatternListT::iterator end() { return patterns.end(); }
|
||||
PatternListT::const_iterator begin() const { return patterns.begin(); }
|
||||
PatternListT::const_iterator end() const { return patterns.end(); }
|
||||
|
||||
//===--------------------------------------------------------------------===//
|
||||
// Pattern Insertion
|
||||
//===--------------------------------------------------------------------===//
|
||||
|
||||
void insert(RewritePattern *pattern) { patterns.emplace_back(pattern); }
|
||||
|
||||
/// Add an instance of each of the pattern types 'Ts' to the pattern list with
|
||||
/// the given arguments.
|
||||
// Note: ConstructorArg is necessary here to separate the two variadic lists.
|
||||
template <typename... Ts, typename ConstructorArg,
|
||||
typename... ConstructorArgs>
|
||||
void insert(ConstructorArg &&arg, ConstructorArgs &&... args) {
|
||||
// The following expands a call to emplace_back for each of the pattern
|
||||
// types 'Ts'. This magic is necessary due to a limitation in the places
|
||||
// that a parameter pack can be expanded in c++11.
|
||||
// FIXME: In c++17 this can be simplified by using 'fold expressions'.
|
||||
using dummy = int[];
|
||||
(void)dummy{
|
||||
0, (patterns.emplace_back(llvm::make_unique<Ts>(arg, args...)), 0)...};
|
||||
}
|
||||
|
||||
private:
|
||||
PatternListT patterns;
|
||||
};
|
||||
|
||||
/// This class manages optimization and execution of a group of rewrite
|
||||
/// patterns, providing an API for finding and applying, the best match against
|
||||
@ -404,7 +435,7 @@ using OwningRewritePatternList = std::vector<std::unique_ptr<RewritePattern>>;
|
||||
class RewritePatternMatcher {
|
||||
public:
|
||||
/// Create a RewritePatternMatcher with the specified set of patterns.
|
||||
explicit RewritePatternMatcher(OwningRewritePatternList &&patterns);
|
||||
explicit RewritePatternMatcher(OwningRewritePatternList &patterns);
|
||||
|
||||
/// Try to match the given operation to a pattern and rewrite it. Return
|
||||
/// true if any pattern matches.
|
||||
@ -416,7 +447,7 @@ private:
|
||||
|
||||
/// The group of patterns that are matched for optimization through this
|
||||
/// matcher.
|
||||
OwningRewritePatternList patterns;
|
||||
std::vector<RewritePattern *> patterns;
|
||||
};
|
||||
|
||||
/// Rewrite the regions of the specified operation, which must be isolated from
|
||||
@ -427,29 +458,6 @@ private:
|
||||
///
|
||||
bool applyPatternsGreedily(Operation *op, OwningRewritePatternList &&patterns);
|
||||
|
||||
/// Helper class to create a list of rewrite patterns given a list of their
|
||||
/// types and a list of attributes perfect-forwarded to each of the conversion
|
||||
/// constructors.
|
||||
template <typename Arg, typename... Args> struct RewriteListBuilder {
|
||||
template <typename... ConstructorArgs>
|
||||
static void build(OwningRewritePatternList &patterns,
|
||||
ConstructorArgs &&... constructorArgs) {
|
||||
RewriteListBuilder<Args...>::build(
|
||||
patterns, std::forward<ConstructorArgs>(constructorArgs)...);
|
||||
RewriteListBuilder<Arg>::build(
|
||||
patterns, std::forward<ConstructorArgs>(constructorArgs)...);
|
||||
}
|
||||
};
|
||||
|
||||
// Template specialization to stop recursion.
|
||||
template <typename Arg> struct RewriteListBuilder<Arg> {
|
||||
template <typename... ConstructorArgs>
|
||||
static void build(OwningRewritePatternList &patterns,
|
||||
ConstructorArgs &&... constructorArgs) {
|
||||
patterns.emplace_back(llvm::make_unique<Arg>(
|
||||
std::forward<ConstructorArgs>(constructorArgs)...));
|
||||
}
|
||||
};
|
||||
} // end namespace mlir
|
||||
|
||||
#endif // MLIR_PATTERN_MATCH_H
|
||||
|
@ -32,7 +32,7 @@ class RewritePattern;
|
||||
class Value;
|
||||
|
||||
// Owning list of rewriting patterns.
|
||||
using OwningRewritePatternList = std::vector<std::unique_ptr<RewritePattern>>;
|
||||
class OwningRewritePatternList;
|
||||
|
||||
/// Emit code that computes the given affine expression using standard
|
||||
/// arithmetic operations applied to the provided dimension and symbol values.
|
||||
|
16
third_party/mlir/lib/AffineOps/AffineOps.cpp
vendored
16
third_party/mlir/lib/AffineOps/AffineOps.cpp
vendored
@ -708,7 +708,7 @@ struct SimplifyAffineApply : public OpRewritePattern<AffineApplyOp> {
|
||||
|
||||
void AffineApplyOp::getCanonicalizationPatterns(
|
||||
OwningRewritePatternList &results, MLIRContext *context) {
|
||||
results.push_back(llvm::make_unique<SimplifyAffineApply>(context));
|
||||
results.insert<SimplifyAffineApply>(context);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
@ -912,8 +912,7 @@ LogicalResult AffineDmaStartOp::verify() {
|
||||
void AffineDmaStartOp::getCanonicalizationPatterns(
|
||||
OwningRewritePatternList &results, MLIRContext *context) {
|
||||
/// dma_start(memrefcast) -> dma_start
|
||||
results.push_back(
|
||||
llvm::make_unique<MemRefCastFolder>(getOperationName(), context));
|
||||
results.insert<MemRefCastFolder>(getOperationName(), context);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
@ -989,8 +988,7 @@ LogicalResult AffineDmaWaitOp::verify() {
|
||||
void AffineDmaWaitOp::getCanonicalizationPatterns(
|
||||
OwningRewritePatternList &results, MLIRContext *context) {
|
||||
/// dma_wait(memrefcast) -> dma_wait
|
||||
results.push_back(
|
||||
llvm::make_unique<MemRefCastFolder>(getOperationName(), context));
|
||||
results.insert<MemRefCastFolder>(getOperationName(), context);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
@ -1333,7 +1331,7 @@ struct AffineForLoopBoundFolder : public OpRewritePattern<AffineForOp> {
|
||||
|
||||
void AffineForOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
||||
MLIRContext *context) {
|
||||
results.push_back(llvm::make_unique<AffineForLoopBoundFolder>(context));
|
||||
results.insert<AffineForLoopBoundFolder>(context);
|
||||
}
|
||||
|
||||
AffineBound AffineForOp::getLowerBound() {
|
||||
@ -1659,8 +1657,7 @@ LogicalResult AffineLoadOp::verify() {
|
||||
void AffineLoadOp::getCanonicalizationPatterns(
|
||||
OwningRewritePatternList &results, MLIRContext *context) {
|
||||
/// load(memrefcast) -> load
|
||||
results.push_back(
|
||||
llvm::make_unique<MemRefCastFolder>(getOperationName(), context));
|
||||
results.insert<MemRefCastFolder>(getOperationName(), context);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
@ -1752,8 +1749,7 @@ LogicalResult AffineStoreOp::verify() {
|
||||
void AffineStoreOp::getCanonicalizationPatterns(
|
||||
OwningRewritePatternList &results, MLIRContext *context) {
|
||||
/// load(memrefcast) -> load
|
||||
results.push_back(
|
||||
llvm::make_unique<MemRefCastFolder>(getOperationName(), context));
|
||||
results.insert<MemRefCastFolder>(getOperationName(), context);
|
||||
}
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
|
@ -258,8 +258,7 @@ IfLowering::matchAndRewrite(IfOp ifOp, PatternRewriter &rewriter) const {
|
||||
|
||||
void mlir::populateLoopToStdConversionPatterns(
|
||||
OwningRewritePatternList &patterns, MLIRContext *ctx) {
|
||||
RewriteListBuilder<ForLowering, IfLowering, TerminatorLowering>::build(
|
||||
patterns, ctx);
|
||||
patterns.insert<ForLowering, IfLowering, TerminatorLowering>(ctx);
|
||||
}
|
||||
|
||||
void ControlFlowToCFGPass::runOnFunction() {
|
||||
|
@ -104,8 +104,7 @@ void GPUToSPIRVPass::runOnModule() {
|
||||
SPIRVTypeConverter typeConverter(context);
|
||||
SPIRVEntryFnTypeConverter entryFnConverter(context);
|
||||
OwningRewritePatternList patterns;
|
||||
RewriteListBuilder<KernelFnConversion>::build(
|
||||
patterns, context, typeConverter, entryFnConverter);
|
||||
patterns.insert<KernelFnConversion>(context, typeConverter, entryFnConverter);
|
||||
populateStandardToSPIRVPatterns(context, patterns);
|
||||
|
||||
ConversionTarget target(*context);
|
||||
|
@ -1023,7 +1023,7 @@ void mlir::LLVM::ensureDistinctSuccessors(ModuleOp m) {
|
||||
void mlir::populateStdToLLVMConversionPatterns(
|
||||
LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
|
||||
// FIXME: this should be tablegen'ed
|
||||
RewriteListBuilder<
|
||||
patterns.insert<
|
||||
AddFOpLowering, AddIOpLowering, AndOpLowering, AllocOpLowering,
|
||||
BranchOpLowering, CallIndirectOpLowering, CallOpLowering, CmpIOpLowering,
|
||||
CondBranchOpLowering, ConstLLVMOpLowering, DeallocOpLowering,
|
||||
@ -1032,8 +1032,7 @@ void mlir::populateStdToLLVMConversionPatterns(
|
||||
MemRefCastOpLowering, MulFOpLowering, MulIOpLowering, OrOpLowering,
|
||||
RemISOpLowering, RemIUOpLowering, RemFOpLowering, ReturnOpLowering,
|
||||
SelectOpLowering, SIToFPLowering, StoreOpLowering, SubFOpLowering,
|
||||
SubIOpLowering, XOrOpLowering>::build(patterns, *converter.getDialect(),
|
||||
converter);
|
||||
SubIOpLowering, XOrOpLowering>(*converter.getDialect(), converter);
|
||||
}
|
||||
|
||||
// Convert types using the stored LLVM IR module.
|
||||
|
@ -201,6 +201,6 @@ void populateStandardToSPIRVPatterns(MLIRContext *context,
|
||||
OwningRewritePatternList &patterns) {
|
||||
populateWithGenerated(context, &patterns);
|
||||
// Add the return op conversion.
|
||||
RewriteListBuilder<ReturnToSPIRVConversion>::build(patterns, context);
|
||||
patterns.insert<ReturnToSPIRVConversion>(context);
|
||||
}
|
||||
} // namespace mlir
|
||||
|
@ -368,8 +368,7 @@ void LowerUniformRealMathPass::runOnFunction() {
|
||||
auto fn = getFunction();
|
||||
OwningRewritePatternList patterns;
|
||||
auto *context = &getContext();
|
||||
patterns.push_back(llvm::make_unique<UniformRealAddEwPattern>(context));
|
||||
patterns.push_back(llvm::make_unique<UniformRealMulEwPattern>(context));
|
||||
patterns.insert<UniformRealAddEwPattern, UniformRealMulEwPattern>(context);
|
||||
applyPatternsGreedily(fn, std::move(patterns));
|
||||
}
|
||||
|
||||
@ -389,7 +388,7 @@ void LowerUniformCastsPass::runOnFunction() {
|
||||
auto fn = getFunction();
|
||||
OwningRewritePatternList patterns;
|
||||
auto *context = &getContext();
|
||||
patterns.push_back(llvm::make_unique<UniformDequantizePattern>(context));
|
||||
patterns.insert<UniformDequantizePattern>(context);
|
||||
applyPatternsGreedily(fn, std::move(patterns));
|
||||
}
|
||||
|
||||
|
@ -372,7 +372,7 @@ class PropagateConstantBounds : public OpRewritePattern<LaunchOp> {
|
||||
|
||||
void LaunchOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
||||
MLIRContext *context) {
|
||||
RewriteListBuilder<PropagateConstantBounds>::build(results, context);
|
||||
results.insert<PropagateConstantBounds>(context);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -60,8 +60,7 @@ public:
|
||||
|
||||
void StorageCastOp::getCanonicalizationPatterns(
|
||||
OwningRewritePatternList &patterns, MLIRContext *context) {
|
||||
patterns.push_back(
|
||||
llvm::make_unique<RemoveRedundantStorageCastsRewrite>(context));
|
||||
patterns.insert<RemoveRedundantStorageCastsRewrite>(context);
|
||||
}
|
||||
|
||||
QuantizationDialect::QuantizationDialect(MLIRContext *context)
|
||||
|
@ -108,7 +108,7 @@ void ConvertConstPass::runOnFunction() {
|
||||
OwningRewritePatternList patterns;
|
||||
auto func = getFunction();
|
||||
auto *context = &getContext();
|
||||
patterns.push_back(llvm::make_unique<QuantizedConstRewrite>(context));
|
||||
patterns.insert<QuantizedConstRewrite>(context);
|
||||
applyPatternsGreedily(func, std::move(patterns));
|
||||
}
|
||||
|
||||
|
@ -97,8 +97,7 @@ void ConvertSimulatedQuantPass::runOnFunction() {
|
||||
OwningRewritePatternList patterns;
|
||||
auto func = getFunction();
|
||||
auto *context = &getContext();
|
||||
patterns.push_back(
|
||||
llvm::make_unique<ConstFakeQuantRewrite>(context, &hadFailure));
|
||||
patterns.insert<ConstFakeQuantRewrite>(context, &hadFailure);
|
||||
applyPatternsGreedily(func, std::move(patterns));
|
||||
if (hadFailure)
|
||||
signalPassFailure();
|
||||
|
11
third_party/mlir/lib/IR/PatternMatch.cpp
vendored
11
third_party/mlir/lib/IR/PatternMatch.cpp
vendored
@ -149,12 +149,13 @@ void PatternRewriter::updatedRootInPlace(
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
RewritePatternMatcher::RewritePatternMatcher(
|
||||
OwningRewritePatternList &&patterns)
|
||||
: patterns(std::move(patterns)) {
|
||||
OwningRewritePatternList &patterns) {
|
||||
for (auto &pattern : patterns)
|
||||
this->patterns.push_back(pattern.get());
|
||||
|
||||
// Sort the patterns by benefit to simplify the matching logic.
|
||||
std::stable_sort(this->patterns.begin(), this->patterns.end(),
|
||||
[](const std::unique_ptr<RewritePattern> &l,
|
||||
const std::unique_ptr<RewritePattern> &r) {
|
||||
[](RewritePattern *l, RewritePattern *r) {
|
||||
return r->getBenefit() < l->getBenefit();
|
||||
});
|
||||
}
|
||||
@ -162,7 +163,7 @@ RewritePatternMatcher::RewritePatternMatcher(
|
||||
/// Try to match the given operation to a pattern and rewrite it.
|
||||
bool RewritePatternMatcher::matchAndRewrite(Operation *op,
|
||||
PatternRewriter &rewriter) {
|
||||
for (auto &pattern : patterns) {
|
||||
for (auto *pattern : patterns) {
|
||||
// Ignore patterns that are for the wrong root or are impossible to match.
|
||||
if (pattern->getRootKind() != op->getName() ||
|
||||
pattern->getBenefit().isImpossibleToMatch())
|
||||
|
@ -678,12 +678,11 @@ static void
|
||||
populateLinalgToLLVMConversionPatterns(LinalgTypeConverter &converter,
|
||||
OwningRewritePatternList &patterns,
|
||||
MLIRContext *ctx) {
|
||||
RewriteListBuilder<BufferAllocOpConversion, BufferDeallocOpConversion,
|
||||
patterns.insert<BufferAllocOpConversion, BufferDeallocOpConversion,
|
||||
BufferSizeOpConversion, DimOpConversion,
|
||||
LinalgOpConversion<DotOp>, LinalgOpConversion<MatmulOp>,
|
||||
LoadOpConversion, RangeOpConversion, SliceOpConversion,
|
||||
StoreOpConversion, ViewOpConversion>::build(patterns, ctx,
|
||||
converter);
|
||||
StoreOpConversion, ViewOpConversion>(ctx, converter);
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
@ -60,12 +60,9 @@ void RemoveInstrumentationPass::runOnFunction() {
|
||||
OwningRewritePatternList patterns;
|
||||
auto func = getFunction();
|
||||
auto *context = &getContext();
|
||||
patterns.push_back(
|
||||
llvm::make_unique<RemoveIdentityOpRewrite<StatisticsOp>>(context));
|
||||
patterns.push_back(
|
||||
llvm::make_unique<RemoveIdentityOpRewrite<StatisticsRefOp>>(context));
|
||||
patterns.push_back(
|
||||
llvm::make_unique<RemoveIdentityOpRewrite<CoupledRefOp>>(context));
|
||||
patterns.insert<RemoveIdentityOpRewrite<StatisticsOp>,
|
||||
RemoveIdentityOpRewrite<StatisticsRefOp>,
|
||||
RemoveIdentityOpRewrite<CoupledRefOp>>(context);
|
||||
applyPatternsGreedily(func, std::move(patterns));
|
||||
}
|
||||
|
||||
|
25
third_party/mlir/lib/StandardOps/Ops.cpp
vendored
25
third_party/mlir/lib/StandardOps/Ops.cpp
vendored
@ -365,8 +365,7 @@ struct SimplifyDeadAlloc : public OpRewritePattern<AllocOp> {
|
||||
|
||||
void AllocOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
||||
MLIRContext *context) {
|
||||
RewriteListBuilder<SimplifyAllocConst, SimplifyDeadAlloc>::build(results,
|
||||
context);
|
||||
results.insert<SimplifyAllocConst, SimplifyDeadAlloc>(context);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
@ -544,8 +543,7 @@ static LogicalResult verify(CallIndirectOp op) {
|
||||
|
||||
void CallIndirectOp::getCanonicalizationPatterns(
|
||||
OwningRewritePatternList &results, MLIRContext *context) {
|
||||
results.push_back(
|
||||
llvm::make_unique<SimplifyIndirectCallWithKnownCallee>(context));
|
||||
results.insert<SimplifyIndirectCallWithKnownCallee>(context);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
@ -1015,7 +1013,7 @@ static void print(OpAsmPrinter *p, CondBranchOp op) {
|
||||
|
||||
void CondBranchOp::getCanonicalizationPatterns(
|
||||
OwningRewritePatternList &results, MLIRContext *context) {
|
||||
results.push_back(llvm::make_unique<SimplifyConstCondBranchPred>(context));
|
||||
results.insert<SimplifyConstCondBranchPred>(context);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
@ -1231,9 +1229,8 @@ static LogicalResult verify(DeallocOp op) {
|
||||
void DeallocOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
||||
MLIRContext *context) {
|
||||
/// dealloc(memrefcast) -> dealloc
|
||||
results.push_back(
|
||||
llvm::make_unique<MemRefCastFolder>(getOperationName(), context));
|
||||
results.push_back(llvm::make_unique<SimplifyDeadDealloc>(context));
|
||||
results.insert<MemRefCastFolder>(getOperationName(), context);
|
||||
results.insert<SimplifyDeadDealloc>(context);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
@ -1497,8 +1494,7 @@ LogicalResult DmaStartOp::verify() {
|
||||
void DmaStartOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
||||
MLIRContext *context) {
|
||||
/// dma_start(memrefcast) -> dma_start
|
||||
results.push_back(
|
||||
llvm::make_unique<MemRefCastFolder>(getOperationName(), context));
|
||||
results.insert<MemRefCastFolder>(getOperationName(), context);
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
@ -1561,8 +1557,7 @@ ParseResult DmaWaitOp::parse(OpAsmParser *parser, OperationState *result) {
|
||||
void DmaWaitOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
||||
MLIRContext *context) {
|
||||
/// dma_wait(memrefcast) -> dma_wait
|
||||
results.push_back(
|
||||
llvm::make_unique<MemRefCastFolder>(getOperationName(), context));
|
||||
results.insert<MemRefCastFolder>(getOperationName(), context);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
@ -1695,8 +1690,7 @@ static LogicalResult verify(LoadOp op) {
|
||||
void LoadOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
||||
MLIRContext *context) {
|
||||
/// load(memrefcast) -> load
|
||||
results.push_back(
|
||||
llvm::make_unique<MemRefCastFolder>(getOperationName(), context));
|
||||
results.insert<MemRefCastFolder>(getOperationName(), context);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
@ -2007,8 +2001,7 @@ static LogicalResult verify(StoreOp op) {
|
||||
void StoreOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
||||
MLIRContext *context) {
|
||||
/// store(memrefcast) -> store
|
||||
results.push_back(
|
||||
llvm::make_unique<MemRefCastFolder>(getOperationName(), context));
|
||||
results.insert<MemRefCastFolder>(getOperationName(), context);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -1243,8 +1243,7 @@ struct FuncOpSignatureConversion : public ConversionPattern {
|
||||
void mlir::populateFuncOpTypeConversionPattern(
|
||||
OwningRewritePatternList &patterns, MLIRContext *ctx,
|
||||
TypeConverter &converter) {
|
||||
RewriteListBuilder<FuncOpSignatureConversion>::build(patterns, ctx,
|
||||
converter);
|
||||
patterns.insert<FuncOpSignatureConversion>(ctx, converter);
|
||||
}
|
||||
|
||||
/// This function converts the type signature of the given block, by invoking
|
||||
|
@ -507,10 +507,11 @@ public:
|
||||
|
||||
void mlir::populateAffineToStdConversionPatterns(
|
||||
OwningRewritePatternList &patterns, MLIRContext *ctx) {
|
||||
RewriteListBuilder<AffineApplyLowering, AffineDmaStartLowering,
|
||||
AffineDmaWaitLowering, AffineLoadLowering,
|
||||
AffineStoreLowering, AffineForLowering, AffineIfLowering,
|
||||
AffineTerminatorLowering>::build(patterns, ctx);
|
||||
patterns
|
||||
.insert<AffineApplyLowering, AffineDmaStartLowering,
|
||||
AffineDmaWaitLowering, AffineLoadLowering, AffineStoreLowering,
|
||||
AffineForLowering, AffineIfLowering, AffineTerminatorLowering>(
|
||||
ctx);
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
@ -365,12 +365,8 @@ struct LowerVectorTransfersPass
|
||||
void runOnFunction() {
|
||||
OwningRewritePatternList patterns;
|
||||
auto *context = &getContext();
|
||||
patterns.push_back(
|
||||
llvm::make_unique<VectorTransferRewriter<VectorTransferReadOp>>(
|
||||
context));
|
||||
patterns.push_back(
|
||||
llvm::make_unique<VectorTransferRewriter<VectorTransferWriteOp>>(
|
||||
context));
|
||||
patterns.insert<VectorTransferRewriter<VectorTransferReadOp>,
|
||||
VectorTransferRewriter<VectorTransferWriteOp>>(context);
|
||||
applyPatternsGreedily(getFunction(), std::move(patterns));
|
||||
}
|
||||
};
|
||||
|
@ -44,8 +44,8 @@ namespace {
|
||||
class GreedyPatternRewriteDriver : public PatternRewriter {
|
||||
public:
|
||||
explicit GreedyPatternRewriteDriver(MLIRContext *ctx,
|
||||
OwningRewritePatternList &&patterns)
|
||||
: PatternRewriter(ctx), matcher(std::move(patterns)) {
|
||||
OwningRewritePatternList &patterns)
|
||||
: PatternRewriter(ctx), matcher(patterns) {
|
||||
worklist.reserve(64);
|
||||
}
|
||||
|
||||
@ -224,7 +224,7 @@ bool mlir::applyPatternsGreedily(Operation *op,
|
||||
if (!op->isKnownIsolatedFromAbove())
|
||||
return false;
|
||||
|
||||
GreedyPatternRewriteDriver driver(op->getContext(), std::move(patterns));
|
||||
GreedyPatternRewriteDriver driver(op->getContext(), patterns);
|
||||
bool converged = driver.simplify(op, maxPatternMatchIterations);
|
||||
LLVM_DEBUG(if (!converged) {
|
||||
llvm::dbgs() << "The pattern rewrite doesn't converge after scanning "
|
||||
|
@ -41,7 +41,7 @@ struct TestPatternDriver : public FunctionPass<TestPatternDriver> {
|
||||
populateWithGenerated(&getContext(), &patterns);
|
||||
|
||||
// Verify named pattern is generated with expected name.
|
||||
RewriteListBuilder<TestNamedPatternRule>::build(patterns, &getContext());
|
||||
patterns.insert<TestNamedPatternRule>(&getContext());
|
||||
|
||||
applyPatternsGreedily(getFunction(), std::move(patterns));
|
||||
}
|
||||
@ -193,9 +193,9 @@ struct TestLegalizePatternDriver
|
||||
TestTypeConverter converter;
|
||||
mlir::OwningRewritePatternList patterns;
|
||||
populateWithGenerated(&getContext(), &patterns);
|
||||
RewriteListBuilder<TestRegionRewriteBlockMovement, TestRegionRewriteUndo,
|
||||
TestDropOp, TestPassthroughInvalidOp,
|
||||
TestSplitReturnType>::build(patterns, &getContext());
|
||||
patterns.insert<TestRegionRewriteBlockMovement, TestRegionRewriteUndo,
|
||||
TestDropOp, TestPassthroughInvalidOp, TestSplitReturnType>(
|
||||
&getContext());
|
||||
mlir::populateFuncOpTypeConversionPattern(patterns, &getContext(),
|
||||
converter);
|
||||
|
||||
|
@ -133,7 +133,7 @@ static LogicalResult runMLIRPasses(ModuleOp m) {
|
||||
pm.addPass(createConvertToLLVMIRPass([](LLVMTypeConverter &converter,
|
||||
OwningRewritePatternList &patterns) {
|
||||
populateStdToLLVMConversionPatterns(converter, patterns);
|
||||
patterns.push_back(llvm::make_unique<GPULaunchFuncOpLowering>(converter));
|
||||
patterns.insert<GPULaunchFuncOpLowering>(converter);
|
||||
}));
|
||||
pm.addPass(createLowerGpuOpsToNVVMOpsPass());
|
||||
pm.addPass(createConvertGPUKernelToCubinPass(&compilePtxToCubin));
|
||||
|
@ -935,8 +935,7 @@ static void emitRewriters(const RecordKeeper &recordKeeper, raw_ostream &os) {
|
||||
os << "void populateWithGenerated(MLIRContext *context, "
|
||||
<< "OwningRewritePatternList *patterns) {\n";
|
||||
for (const auto &name : rewriterNames) {
|
||||
os << " patterns->push_back(llvm::make_unique<" << name
|
||||
<< ">(context));\n";
|
||||
os << " patterns->insert<" << name << ">(context);\n";
|
||||
}
|
||||
os << "}\n";
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user