NFC: Implement OwningRewritePatternList as a class instead of a using directive.
This allows for proper forward declaration, as opposed to leaking the internal implementation via a using directive. This also allows for all pattern building to go through 'insert' methods on the OwningRewritePatternList, replacing uses of 'push_back' and 'RewriteListBuilder'. PiperOrigin-RevId: 261816316
This commit is contained in:
parent
85a9058eff
commit
4add221b50
@ -506,7 +506,7 @@ OpFoldResult ReshapeOp::fold(ArrayRef<Attribute> operands) {
|
|||||||
|
|
||||||
void ReshapeOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
void ReshapeOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
||||||
MLIRContext *context) {
|
MLIRContext *context) {
|
||||||
results.push_back(llvm::make_unique<RemoveAdjacentReshape>(context));
|
results.insert<RemoveAdjacentReshape>(context);
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
@ -591,7 +591,7 @@ struct DropFakeQuant : public RewritePattern {
|
|||||||
|
|
||||||
void FakeQuantOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
void FakeQuantOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
||||||
MLIRContext *context) {
|
MLIRContext *context) {
|
||||||
results.push_back(llvm::make_unique<DropFakeQuant>(context));
|
results.insert<DropFakeQuant>(context);
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -205,9 +205,9 @@ void LegalizeTF::runOnFunction() {
|
|||||||
|
|
||||||
// Add the generated patterns to the list.
|
// Add the generated patterns to the list.
|
||||||
populateWithGenerated(ctx, &patterns);
|
populateWithGenerated(ctx, &patterns);
|
||||||
RewriteListBuilder<ConvertTFConcatOp, ConvertTFConcatV2Op, ConvertTFMatMulOp,
|
patterns.insert<ConvertTFConcatOp, ConvertTFConcatV2Op, ConvertTFMatMulOp,
|
||||||
ConvertTFPackOp, ConvertTFSplitOp, ConvertTFSplitVOp,
|
ConvertTFPackOp, ConvertTFSplitOp, ConvertTFSplitVOp,
|
||||||
ConvertTFUnpackOp>::build(patterns, ctx);
|
ConvertTFUnpackOp>(ctx);
|
||||||
applyPatternsGreedily(func, std::move(patterns));
|
applyPatternsGreedily(func, std::move(patterns));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -240,8 +240,8 @@ void Optimize::runOnFunction() {
|
|||||||
auto func = getFunction();
|
auto func = getFunction();
|
||||||
// Add the generated patterns to the list.
|
// Add the generated patterns to the list.
|
||||||
TFL::populateWithGenerated(ctx, &patterns);
|
TFL::populateWithGenerated(ctx, &patterns);
|
||||||
RewriteListBuilder<FuseFullyConnectedAndAdd, FuseFullyConnectedAndRelu,
|
patterns.insert<FuseFullyConnectedAndAdd, FuseFullyConnectedAndRelu,
|
||||||
PadStridedSliceDims>::build(patterns, ctx);
|
PadStridedSliceDims>(ctx);
|
||||||
|
|
||||||
applyPatternsGreedily(func, std::move(patterns));
|
applyPatternsGreedily(func, std::move(patterns));
|
||||||
}
|
}
|
||||||
|
@ -381,8 +381,7 @@ void PrepareTFPass::runOnFunction() {
|
|||||||
// parameters from the TF Quant ops, thus this pattern should run with the
|
// parameters from the TF Quant ops, thus this pattern should run with the
|
||||||
// first `applyPatternsGreedily` method, which would otherwise removes the
|
// first `applyPatternsGreedily` method, which would otherwise removes the
|
||||||
// TF FakeQuant ops by the constant folding.
|
// TF FakeQuant ops by the constant folding.
|
||||||
patterns.push_back(
|
patterns.insert<InsertTFLQuantOpsAfterTFFakeQuantOp>(&getContext());
|
||||||
llvm::make_unique<InsertTFLQuantOpsAfterTFFakeQuantOp>(&getContext()));
|
|
||||||
TFL::populateWithGenerated(&getContext(), &patterns);
|
TFL::populateWithGenerated(&getContext(), &patterns);
|
||||||
// TODO(karimnosseir): Split to separate pass probably after
|
// TODO(karimnosseir): Split to separate pass probably after
|
||||||
// deciding on long term plan for this optimization.
|
// deciding on long term plan for this optimization.
|
||||||
@ -394,9 +393,8 @@ void PrepareTFPass::runOnFunction() {
|
|||||||
// Load the generated pattern again, so new quantization pass-through
|
// Load the generated pattern again, so new quantization pass-through
|
||||||
// will be applied.
|
// will be applied.
|
||||||
TFL::populateWithGenerated(&getContext(), &patterns);
|
TFL::populateWithGenerated(&getContext(), &patterns);
|
||||||
patterns.push_back(llvm::make_unique<ConvertTFConv2D>(&getContext()));
|
patterns.insert<ConvertTFConv2D, ConvertTFDepthwiseConv2dNative>(
|
||||||
patterns.push_back(
|
&getContext());
|
||||||
llvm::make_unique<ConvertTFDepthwiseConv2dNative>(&getContext()));
|
|
||||||
applyPatternsGreedily(func, std::move(patterns));
|
applyPatternsGreedily(func, std::move(patterns));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -55,8 +55,8 @@ void QuantizePass::runOnFunction() {
|
|||||||
auto func = getFunction();
|
auto func = getFunction();
|
||||||
auto* ctx = func.getContext();
|
auto* ctx = func.getContext();
|
||||||
TFL::populateWithGenerated(ctx, &patterns);
|
TFL::populateWithGenerated(ctx, &patterns);
|
||||||
mlir::RewriteListBuilder<mlir::TFL::GenericFullQuantizationPattern<
|
patterns.insert<mlir::TFL::GenericFullQuantizationPattern<
|
||||||
mlir::TFL::QuantizeOp, mlir::TFL::DequantizeOp>>::build(patterns, ctx);
|
mlir::TFL::QuantizeOp, mlir::TFL::DequantizeOp>>(ctx);
|
||||||
applyPatternsGreedily(func, std::move(patterns));
|
applyPatternsGreedily(func, std::move(patterns));
|
||||||
}
|
}
|
||||||
} // namespace
|
} // namespace
|
||||||
|
@ -106,7 +106,7 @@ namespace {
|
|||||||
|
|
||||||
void AddOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
void AddOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
||||||
MLIRContext *context) {
|
MLIRContext *context) {
|
||||||
RewriteListBuilder<AddToAddV2>::build(results, context);
|
results.insert<AddToAddV2>(context);
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
@ -115,7 +115,7 @@ void AddOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
|||||||
|
|
||||||
void AddV2Op::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
void AddV2Op::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
||||||
MLIRContext *context) {
|
MLIRContext *context) {
|
||||||
RewriteListBuilder<AddV2OfNegLeft, AddV2OfNegRight>::build(results, context);
|
results.insert<AddV2OfNegLeft, AddV2OfNegRight>(context);
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
@ -144,7 +144,7 @@ struct AssertWithTrue : public OpRewritePattern<AssertOp> {
|
|||||||
|
|
||||||
void AssertOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
void AssertOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
||||||
MLIRContext *context) {
|
MLIRContext *context) {
|
||||||
RewriteListBuilder<AssertWithTrue>::build(results, context);
|
results.insert<AssertWithTrue>(context);
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
@ -153,7 +153,7 @@ void AssertOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
|||||||
|
|
||||||
void BitcastOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
void BitcastOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
||||||
MLIRContext *context) {
|
MLIRContext *context) {
|
||||||
RewriteListBuilder<BitcastSameType, BitcastNested>::build(results, context);
|
results.insert<BitcastSameType, BitcastNested>(context);
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
@ -174,7 +174,7 @@ static LogicalResult Verify(BroadcastToOp op) {
|
|||||||
|
|
||||||
void CastOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
void CastOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
||||||
MLIRContext *context) {
|
MLIRContext *context) {
|
||||||
RewriteListBuilder<CastSameType>::build(results, context);
|
results.insert<CastSameType>(context);
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
@ -183,7 +183,7 @@ void CastOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
|||||||
|
|
||||||
void ConjOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
void ConjOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
||||||
MLIRContext *context) {
|
MLIRContext *context) {
|
||||||
RewriteListBuilder<ConjNested>::build(results, context);
|
results.insert<ConjNested>(context);
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
@ -239,7 +239,7 @@ void ConstOp::build(Builder *builder, OperationState *result, Type type,
|
|||||||
|
|
||||||
void DivOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
void DivOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
||||||
MLIRContext *context) {
|
MLIRContext *context) {
|
||||||
RewriteListBuilder<DivWithSqrtDivisor>::build(results, context);
|
results.insert<DivWithSqrtDivisor>(context);
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
@ -415,7 +415,7 @@ static LogicalResult Verify(IfOp op) {
|
|||||||
|
|
||||||
void InvertOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
void InvertOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
||||||
MLIRContext *context) {
|
MLIRContext *context) {
|
||||||
RewriteListBuilder<InvertNested>::build(results, context);
|
results.insert<InvertNested>(context);
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
@ -449,7 +449,7 @@ OpFoldResult LeakyReluOp::fold(ArrayRef<Attribute> operands) {
|
|||||||
|
|
||||||
void LogOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
void LogOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
||||||
MLIRContext *context) {
|
MLIRContext *context) {
|
||||||
RewriteListBuilder<LogOfSoftmax>::build(results, context);
|
results.insert<LogOfSoftmax>(context);
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
@ -458,10 +458,9 @@ void LogOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
|||||||
|
|
||||||
void LogicalNotOp::getCanonicalizationPatterns(
|
void LogicalNotOp::getCanonicalizationPatterns(
|
||||||
OwningRewritePatternList &results, MLIRContext *context) {
|
OwningRewritePatternList &results, MLIRContext *context) {
|
||||||
RewriteListBuilder<LogicalNotNested, LogicalNotOfEqual, LogicalNotOfNotEqual,
|
results.insert<LogicalNotNested, LogicalNotOfEqual, LogicalNotOfNotEqual,
|
||||||
LogicalNotOfGreater, LogicalNotOfGreaterEqual,
|
LogicalNotOfGreater, LogicalNotOfGreaterEqual,
|
||||||
LogicalNotOfLess, LogicalNotOfLessEqual>::build(results,
|
LogicalNotOfLess, LogicalNotOfLessEqual>(context);
|
||||||
context);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
@ -470,7 +469,7 @@ void LogicalNotOp::getCanonicalizationPatterns(
|
|||||||
|
|
||||||
void NegOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
void NegOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
||||||
MLIRContext *context) {
|
MLIRContext *context) {
|
||||||
RewriteListBuilder<NegNested>::build(results, context);
|
results.insert<NegNested>(context);
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
@ -479,7 +478,7 @@ void NegOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
|||||||
|
|
||||||
void ReciprocalOp::getCanonicalizationPatterns(
|
void ReciprocalOp::getCanonicalizationPatterns(
|
||||||
OwningRewritePatternList &results, MLIRContext *context) {
|
OwningRewritePatternList &results, MLIRContext *context) {
|
||||||
RewriteListBuilder<ReciprocalNested>::build(results, context);
|
results.insert<ReciprocalNested>(context);
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
@ -538,7 +537,7 @@ void RankOp::build(Builder *builder, OperationState *result, Value *input) {
|
|||||||
|
|
||||||
void RealDivOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
void RealDivOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
||||||
MLIRContext *context) {
|
MLIRContext *context) {
|
||||||
RewriteListBuilder<RealDivWithSqrtDivisor>::build(results, context);
|
results.insert<RealDivWithSqrtDivisor>(context);
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
@ -698,7 +697,7 @@ static LogicalResult Verify(SoftmaxOp op) {
|
|||||||
|
|
||||||
void SquareOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
void SquareOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
||||||
MLIRContext *context) {
|
MLIRContext *context) {
|
||||||
RewriteListBuilder<SquareOfSub>::build(results, context);
|
results.insert<SquareOfSub>(context);
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
@ -707,7 +706,7 @@ void SquareOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
|||||||
|
|
||||||
void SubOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
void SubOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
||||||
MLIRContext *context) {
|
MLIRContext *context) {
|
||||||
RewriteListBuilder<SubOfNeg>::build(results, context);
|
results.insert<SubOfNeg>(context);
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
@ -778,7 +777,7 @@ void TransposeOp::build(Builder *builder, OperationState *result, Value *x,
|
|||||||
|
|
||||||
void TruncateDivOp::getCanonicalizationPatterns(
|
void TruncateDivOp::getCanonicalizationPatterns(
|
||||||
OwningRewritePatternList &results, MLIRContext *context) {
|
OwningRewritePatternList &results, MLIRContext *context) {
|
||||||
RewriteListBuilder<TruncateDivWithSqrtDivisor>::build(results, context);
|
results.insert<TruncateDivWithSqrtDivisor>(context);
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
@ -867,7 +866,7 @@ static LogicalResult Verify(WhileOp op) {
|
|||||||
|
|
||||||
void XdivyOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
void XdivyOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
||||||
MLIRContext *context) {
|
MLIRContext *context) {
|
||||||
RewriteListBuilder<XdivyWithSqrtDivisor>::build(results, context);
|
results.insert<XdivyWithSqrtDivisor>(context);
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -133,10 +133,8 @@ void LegalizeToStandard::runOnFunction() {
|
|||||||
auto func = getFunction();
|
auto func = getFunction();
|
||||||
|
|
||||||
mlir::XLA::populateWithGenerated(func.getContext(), &patterns);
|
mlir::XLA::populateWithGenerated(func.getContext(), &patterns);
|
||||||
patterns.push_back(
|
patterns.insert<mlir::XLA::CompareFConvert, mlir::XLA::CompareIConvert>(
|
||||||
llvm::make_unique<mlir::XLA::CompareFConvert>(&getContext()));
|
&getContext());
|
||||||
patterns.push_back(
|
|
||||||
llvm::make_unique<mlir::XLA::CompareIConvert>(&getContext()));
|
|
||||||
applyPatternsGreedily(func, std::move(patterns));
|
applyPatternsGreedily(func, std::move(patterns));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -29,7 +29,7 @@ class MLIRContext;
|
|||||||
class RewritePattern;
|
class RewritePattern;
|
||||||
|
|
||||||
// Owning list of rewriting patterns.
|
// Owning list of rewriting patterns.
|
||||||
using OwningRewritePatternList = std::vector<std::unique_ptr<RewritePattern>>;
|
class OwningRewritePatternList;
|
||||||
|
|
||||||
/// Collect a set of patterns to lower from loop.for, loop.if, and
|
/// Collect a set of patterns to lower from loop.for, loop.if, and
|
||||||
/// loop.terminator to CFG operations within the Standard dialect, in particular
|
/// loop.terminator to CFG operations within the Standard dialect, in particular
|
||||||
|
@ -38,7 +38,7 @@ class RewritePattern;
|
|||||||
class Type;
|
class Type;
|
||||||
|
|
||||||
// Owning list of rewriting patterns.
|
// Owning list of rewriting patterns.
|
||||||
using OwningRewritePatternList = std::vector<std::unique_ptr<RewritePattern>>;
|
class OwningRewritePatternList;
|
||||||
|
|
||||||
/// Type for a callback constructing the owning list of patterns for the
|
/// Type for a callback constructing the owning list of patterns for the
|
||||||
/// conversion to the LLVMIR dialect. The callback is expected to append
|
/// conversion to the LLVMIR dialect. The callback is expected to append
|
||||||
|
@ -57,9 +57,7 @@ class Value;
|
|||||||
/// either OpTy or OperandAdaptor<OpTy> seamlessly.
|
/// either OpTy or OperandAdaptor<OpTy> seamlessly.
|
||||||
template <typename OpTy> using OperandAdaptor = typename OpTy::OperandAdaptor;
|
template <typename OpTy> using OperandAdaptor = typename OpTy::OperandAdaptor;
|
||||||
|
|
||||||
/// This is a vector that owns the patterns inside of it.
|
class OwningRewritePatternList;
|
||||||
using OwningPatternList = std::vector<std::unique_ptr<Pattern>>;
|
|
||||||
using OwningRewritePatternList = std::vector<std::unique_ptr<RewritePattern>>;
|
|
||||||
|
|
||||||
enum class OperationProperty {
|
enum class OperationProperty {
|
||||||
/// This bit is set for an operation if it is a commutative operation: that
|
/// This bit is set for an operation if it is a commutative operation: that
|
||||||
|
62
third_party/mlir/include/mlir/IR/PatternMatch.h
vendored
62
third_party/mlir/include/mlir/IR/PatternMatch.h
vendored
@ -394,8 +394,39 @@ private:
|
|||||||
// Pattern-driven rewriters
|
// Pattern-driven rewriters
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
/// This is a vector that owns the patterns inside of it.
|
class OwningRewritePatternList {
|
||||||
using OwningRewritePatternList = std::vector<std::unique_ptr<RewritePattern>>;
|
using PatternListT = std::vector<std::unique_ptr<RewritePattern>>;
|
||||||
|
|
||||||
|
public:
|
||||||
|
PatternListT::iterator begin() { return patterns.begin(); }
|
||||||
|
PatternListT::iterator end() { return patterns.end(); }
|
||||||
|
PatternListT::const_iterator begin() const { return patterns.begin(); }
|
||||||
|
PatternListT::const_iterator end() const { return patterns.end(); }
|
||||||
|
|
||||||
|
//===--------------------------------------------------------------------===//
|
||||||
|
// Pattern Insertion
|
||||||
|
//===--------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
void insert(RewritePattern *pattern) { patterns.emplace_back(pattern); }
|
||||||
|
|
||||||
|
/// Add an instance of each of the pattern types 'Ts' to the pattern list with
|
||||||
|
/// the given arguments.
|
||||||
|
// Note: ConstructorArg is necessary here to separate the two variadic lists.
|
||||||
|
template <typename... Ts, typename ConstructorArg,
|
||||||
|
typename... ConstructorArgs>
|
||||||
|
void insert(ConstructorArg &&arg, ConstructorArgs &&... args) {
|
||||||
|
// The following expands a call to emplace_back for each of the pattern
|
||||||
|
// types 'Ts'. This magic is necessary due to a limitation in the places
|
||||||
|
// that a parameter pack can be expanded in c++11.
|
||||||
|
// FIXME: In c++17 this can be simplified by using 'fold expressions'.
|
||||||
|
using dummy = int[];
|
||||||
|
(void)dummy{
|
||||||
|
0, (patterns.emplace_back(llvm::make_unique<Ts>(arg, args...)), 0)...};
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
PatternListT patterns;
|
||||||
|
};
|
||||||
|
|
||||||
/// This class manages optimization and execution of a group of rewrite
|
/// This class manages optimization and execution of a group of rewrite
|
||||||
/// patterns, providing an API for finding and applying, the best match against
|
/// patterns, providing an API for finding and applying, the best match against
|
||||||
@ -404,7 +435,7 @@ using OwningRewritePatternList = std::vector<std::unique_ptr<RewritePattern>>;
|
|||||||
class RewritePatternMatcher {
|
class RewritePatternMatcher {
|
||||||
public:
|
public:
|
||||||
/// Create a RewritePatternMatcher with the specified set of patterns.
|
/// Create a RewritePatternMatcher with the specified set of patterns.
|
||||||
explicit RewritePatternMatcher(OwningRewritePatternList &&patterns);
|
explicit RewritePatternMatcher(OwningRewritePatternList &patterns);
|
||||||
|
|
||||||
/// Try to match the given operation to a pattern and rewrite it. Return
|
/// Try to match the given operation to a pattern and rewrite it. Return
|
||||||
/// true if any pattern matches.
|
/// true if any pattern matches.
|
||||||
@ -416,7 +447,7 @@ private:
|
|||||||
|
|
||||||
/// The group of patterns that are matched for optimization through this
|
/// The group of patterns that are matched for optimization through this
|
||||||
/// matcher.
|
/// matcher.
|
||||||
OwningRewritePatternList patterns;
|
std::vector<RewritePattern *> patterns;
|
||||||
};
|
};
|
||||||
|
|
||||||
/// Rewrite the regions of the specified operation, which must be isolated from
|
/// Rewrite the regions of the specified operation, which must be isolated from
|
||||||
@ -427,29 +458,6 @@ private:
|
|||||||
///
|
///
|
||||||
bool applyPatternsGreedily(Operation *op, OwningRewritePatternList &&patterns);
|
bool applyPatternsGreedily(Operation *op, OwningRewritePatternList &&patterns);
|
||||||
|
|
||||||
/// Helper class to create a list of rewrite patterns given a list of their
|
|
||||||
/// types and a list of attributes perfect-forwarded to each of the conversion
|
|
||||||
/// constructors.
|
|
||||||
template <typename Arg, typename... Args> struct RewriteListBuilder {
|
|
||||||
template <typename... ConstructorArgs>
|
|
||||||
static void build(OwningRewritePatternList &patterns,
|
|
||||||
ConstructorArgs &&... constructorArgs) {
|
|
||||||
RewriteListBuilder<Args...>::build(
|
|
||||||
patterns, std::forward<ConstructorArgs>(constructorArgs)...);
|
|
||||||
RewriteListBuilder<Arg>::build(
|
|
||||||
patterns, std::forward<ConstructorArgs>(constructorArgs)...);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
// Template specialization to stop recursion.
|
|
||||||
template <typename Arg> struct RewriteListBuilder<Arg> {
|
|
||||||
template <typename... ConstructorArgs>
|
|
||||||
static void build(OwningRewritePatternList &patterns,
|
|
||||||
ConstructorArgs &&... constructorArgs) {
|
|
||||||
patterns.emplace_back(llvm::make_unique<Arg>(
|
|
||||||
std::forward<ConstructorArgs>(constructorArgs)...));
|
|
||||||
}
|
|
||||||
};
|
|
||||||
} // end namespace mlir
|
} // end namespace mlir
|
||||||
|
|
||||||
#endif // MLIR_PATTERN_MATCH_H
|
#endif // MLIR_PATTERN_MATCH_H
|
||||||
|
@ -32,7 +32,7 @@ class RewritePattern;
|
|||||||
class Value;
|
class Value;
|
||||||
|
|
||||||
// Owning list of rewriting patterns.
|
// Owning list of rewriting patterns.
|
||||||
using OwningRewritePatternList = std::vector<std::unique_ptr<RewritePattern>>;
|
class OwningRewritePatternList;
|
||||||
|
|
||||||
/// Emit code that computes the given affine expression using standard
|
/// Emit code that computes the given affine expression using standard
|
||||||
/// arithmetic operations applied to the provided dimension and symbol values.
|
/// arithmetic operations applied to the provided dimension and symbol values.
|
||||||
|
16
third_party/mlir/lib/AffineOps/AffineOps.cpp
vendored
16
third_party/mlir/lib/AffineOps/AffineOps.cpp
vendored
@ -708,7 +708,7 @@ struct SimplifyAffineApply : public OpRewritePattern<AffineApplyOp> {
|
|||||||
|
|
||||||
void AffineApplyOp::getCanonicalizationPatterns(
|
void AffineApplyOp::getCanonicalizationPatterns(
|
||||||
OwningRewritePatternList &results, MLIRContext *context) {
|
OwningRewritePatternList &results, MLIRContext *context) {
|
||||||
results.push_back(llvm::make_unique<SimplifyAffineApply>(context));
|
results.insert<SimplifyAffineApply>(context);
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
@ -912,8 +912,7 @@ LogicalResult AffineDmaStartOp::verify() {
|
|||||||
void AffineDmaStartOp::getCanonicalizationPatterns(
|
void AffineDmaStartOp::getCanonicalizationPatterns(
|
||||||
OwningRewritePatternList &results, MLIRContext *context) {
|
OwningRewritePatternList &results, MLIRContext *context) {
|
||||||
/// dma_start(memrefcast) -> dma_start
|
/// dma_start(memrefcast) -> dma_start
|
||||||
results.push_back(
|
results.insert<MemRefCastFolder>(getOperationName(), context);
|
||||||
llvm::make_unique<MemRefCastFolder>(getOperationName(), context));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
@ -989,8 +988,7 @@ LogicalResult AffineDmaWaitOp::verify() {
|
|||||||
void AffineDmaWaitOp::getCanonicalizationPatterns(
|
void AffineDmaWaitOp::getCanonicalizationPatterns(
|
||||||
OwningRewritePatternList &results, MLIRContext *context) {
|
OwningRewritePatternList &results, MLIRContext *context) {
|
||||||
/// dma_wait(memrefcast) -> dma_wait
|
/// dma_wait(memrefcast) -> dma_wait
|
||||||
results.push_back(
|
results.insert<MemRefCastFolder>(getOperationName(), context);
|
||||||
llvm::make_unique<MemRefCastFolder>(getOperationName(), context));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
@ -1333,7 +1331,7 @@ struct AffineForLoopBoundFolder : public OpRewritePattern<AffineForOp> {
|
|||||||
|
|
||||||
void AffineForOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
void AffineForOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
||||||
MLIRContext *context) {
|
MLIRContext *context) {
|
||||||
results.push_back(llvm::make_unique<AffineForLoopBoundFolder>(context));
|
results.insert<AffineForLoopBoundFolder>(context);
|
||||||
}
|
}
|
||||||
|
|
||||||
AffineBound AffineForOp::getLowerBound() {
|
AffineBound AffineForOp::getLowerBound() {
|
||||||
@ -1659,8 +1657,7 @@ LogicalResult AffineLoadOp::verify() {
|
|||||||
void AffineLoadOp::getCanonicalizationPatterns(
|
void AffineLoadOp::getCanonicalizationPatterns(
|
||||||
OwningRewritePatternList &results, MLIRContext *context) {
|
OwningRewritePatternList &results, MLIRContext *context) {
|
||||||
/// load(memrefcast) -> load
|
/// load(memrefcast) -> load
|
||||||
results.push_back(
|
results.insert<MemRefCastFolder>(getOperationName(), context);
|
||||||
llvm::make_unique<MemRefCastFolder>(getOperationName(), context));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
@ -1752,8 +1749,7 @@ LogicalResult AffineStoreOp::verify() {
|
|||||||
void AffineStoreOp::getCanonicalizationPatterns(
|
void AffineStoreOp::getCanonicalizationPatterns(
|
||||||
OwningRewritePatternList &results, MLIRContext *context) {
|
OwningRewritePatternList &results, MLIRContext *context) {
|
||||||
/// load(memrefcast) -> load
|
/// load(memrefcast) -> load
|
||||||
results.push_back(
|
results.insert<MemRefCastFolder>(getOperationName(), context);
|
||||||
llvm::make_unique<MemRefCastFolder>(getOperationName(), context));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#define GET_OP_CLASSES
|
#define GET_OP_CLASSES
|
||||||
|
@ -258,8 +258,7 @@ IfLowering::matchAndRewrite(IfOp ifOp, PatternRewriter &rewriter) const {
|
|||||||
|
|
||||||
void mlir::populateLoopToStdConversionPatterns(
|
void mlir::populateLoopToStdConversionPatterns(
|
||||||
OwningRewritePatternList &patterns, MLIRContext *ctx) {
|
OwningRewritePatternList &patterns, MLIRContext *ctx) {
|
||||||
RewriteListBuilder<ForLowering, IfLowering, TerminatorLowering>::build(
|
patterns.insert<ForLowering, IfLowering, TerminatorLowering>(ctx);
|
||||||
patterns, ctx);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void ControlFlowToCFGPass::runOnFunction() {
|
void ControlFlowToCFGPass::runOnFunction() {
|
||||||
|
@ -104,8 +104,7 @@ void GPUToSPIRVPass::runOnModule() {
|
|||||||
SPIRVTypeConverter typeConverter(context);
|
SPIRVTypeConverter typeConverter(context);
|
||||||
SPIRVEntryFnTypeConverter entryFnConverter(context);
|
SPIRVEntryFnTypeConverter entryFnConverter(context);
|
||||||
OwningRewritePatternList patterns;
|
OwningRewritePatternList patterns;
|
||||||
RewriteListBuilder<KernelFnConversion>::build(
|
patterns.insert<KernelFnConversion>(context, typeConverter, entryFnConverter);
|
||||||
patterns, context, typeConverter, entryFnConverter);
|
|
||||||
populateStandardToSPIRVPatterns(context, patterns);
|
populateStandardToSPIRVPatterns(context, patterns);
|
||||||
|
|
||||||
ConversionTarget target(*context);
|
ConversionTarget target(*context);
|
||||||
|
@ -1023,7 +1023,7 @@ void mlir::LLVM::ensureDistinctSuccessors(ModuleOp m) {
|
|||||||
void mlir::populateStdToLLVMConversionPatterns(
|
void mlir::populateStdToLLVMConversionPatterns(
|
||||||
LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
|
LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
|
||||||
// FIXME: this should be tablegen'ed
|
// FIXME: this should be tablegen'ed
|
||||||
RewriteListBuilder<
|
patterns.insert<
|
||||||
AddFOpLowering, AddIOpLowering, AndOpLowering, AllocOpLowering,
|
AddFOpLowering, AddIOpLowering, AndOpLowering, AllocOpLowering,
|
||||||
BranchOpLowering, CallIndirectOpLowering, CallOpLowering, CmpIOpLowering,
|
BranchOpLowering, CallIndirectOpLowering, CallOpLowering, CmpIOpLowering,
|
||||||
CondBranchOpLowering, ConstLLVMOpLowering, DeallocOpLowering,
|
CondBranchOpLowering, ConstLLVMOpLowering, DeallocOpLowering,
|
||||||
@ -1032,8 +1032,7 @@ void mlir::populateStdToLLVMConversionPatterns(
|
|||||||
MemRefCastOpLowering, MulFOpLowering, MulIOpLowering, OrOpLowering,
|
MemRefCastOpLowering, MulFOpLowering, MulIOpLowering, OrOpLowering,
|
||||||
RemISOpLowering, RemIUOpLowering, RemFOpLowering, ReturnOpLowering,
|
RemISOpLowering, RemIUOpLowering, RemFOpLowering, ReturnOpLowering,
|
||||||
SelectOpLowering, SIToFPLowering, StoreOpLowering, SubFOpLowering,
|
SelectOpLowering, SIToFPLowering, StoreOpLowering, SubFOpLowering,
|
||||||
SubIOpLowering, XOrOpLowering>::build(patterns, *converter.getDialect(),
|
SubIOpLowering, XOrOpLowering>(*converter.getDialect(), converter);
|
||||||
converter);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Convert types using the stored LLVM IR module.
|
// Convert types using the stored LLVM IR module.
|
||||||
|
@ -201,6 +201,6 @@ void populateStandardToSPIRVPatterns(MLIRContext *context,
|
|||||||
OwningRewritePatternList &patterns) {
|
OwningRewritePatternList &patterns) {
|
||||||
populateWithGenerated(context, &patterns);
|
populateWithGenerated(context, &patterns);
|
||||||
// Add the return op conversion.
|
// Add the return op conversion.
|
||||||
RewriteListBuilder<ReturnToSPIRVConversion>::build(patterns, context);
|
patterns.insert<ReturnToSPIRVConversion>(context);
|
||||||
}
|
}
|
||||||
} // namespace mlir
|
} // namespace mlir
|
||||||
|
@ -368,8 +368,7 @@ void LowerUniformRealMathPass::runOnFunction() {
|
|||||||
auto fn = getFunction();
|
auto fn = getFunction();
|
||||||
OwningRewritePatternList patterns;
|
OwningRewritePatternList patterns;
|
||||||
auto *context = &getContext();
|
auto *context = &getContext();
|
||||||
patterns.push_back(llvm::make_unique<UniformRealAddEwPattern>(context));
|
patterns.insert<UniformRealAddEwPattern, UniformRealMulEwPattern>(context);
|
||||||
patterns.push_back(llvm::make_unique<UniformRealMulEwPattern>(context));
|
|
||||||
applyPatternsGreedily(fn, std::move(patterns));
|
applyPatternsGreedily(fn, std::move(patterns));
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -389,7 +388,7 @@ void LowerUniformCastsPass::runOnFunction() {
|
|||||||
auto fn = getFunction();
|
auto fn = getFunction();
|
||||||
OwningRewritePatternList patterns;
|
OwningRewritePatternList patterns;
|
||||||
auto *context = &getContext();
|
auto *context = &getContext();
|
||||||
patterns.push_back(llvm::make_unique<UniformDequantizePattern>(context));
|
patterns.insert<UniformDequantizePattern>(context);
|
||||||
applyPatternsGreedily(fn, std::move(patterns));
|
applyPatternsGreedily(fn, std::move(patterns));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -372,7 +372,7 @@ class PropagateConstantBounds : public OpRewritePattern<LaunchOp> {
|
|||||||
|
|
||||||
void LaunchOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
void LaunchOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
||||||
MLIRContext *context) {
|
MLIRContext *context) {
|
||||||
RewriteListBuilder<PropagateConstantBounds>::build(results, context);
|
results.insert<PropagateConstantBounds>(context);
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -60,8 +60,7 @@ public:
|
|||||||
|
|
||||||
void StorageCastOp::getCanonicalizationPatterns(
|
void StorageCastOp::getCanonicalizationPatterns(
|
||||||
OwningRewritePatternList &patterns, MLIRContext *context) {
|
OwningRewritePatternList &patterns, MLIRContext *context) {
|
||||||
patterns.push_back(
|
patterns.insert<RemoveRedundantStorageCastsRewrite>(context);
|
||||||
llvm::make_unique<RemoveRedundantStorageCastsRewrite>(context));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
QuantizationDialect::QuantizationDialect(MLIRContext *context)
|
QuantizationDialect::QuantizationDialect(MLIRContext *context)
|
||||||
|
@ -108,7 +108,7 @@ void ConvertConstPass::runOnFunction() {
|
|||||||
OwningRewritePatternList patterns;
|
OwningRewritePatternList patterns;
|
||||||
auto func = getFunction();
|
auto func = getFunction();
|
||||||
auto *context = &getContext();
|
auto *context = &getContext();
|
||||||
patterns.push_back(llvm::make_unique<QuantizedConstRewrite>(context));
|
patterns.insert<QuantizedConstRewrite>(context);
|
||||||
applyPatternsGreedily(func, std::move(patterns));
|
applyPatternsGreedily(func, std::move(patterns));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -97,8 +97,7 @@ void ConvertSimulatedQuantPass::runOnFunction() {
|
|||||||
OwningRewritePatternList patterns;
|
OwningRewritePatternList patterns;
|
||||||
auto func = getFunction();
|
auto func = getFunction();
|
||||||
auto *context = &getContext();
|
auto *context = &getContext();
|
||||||
patterns.push_back(
|
patterns.insert<ConstFakeQuantRewrite>(context, &hadFailure);
|
||||||
llvm::make_unique<ConstFakeQuantRewrite>(context, &hadFailure));
|
|
||||||
applyPatternsGreedily(func, std::move(patterns));
|
applyPatternsGreedily(func, std::move(patterns));
|
||||||
if (hadFailure)
|
if (hadFailure)
|
||||||
signalPassFailure();
|
signalPassFailure();
|
||||||
|
11
third_party/mlir/lib/IR/PatternMatch.cpp
vendored
11
third_party/mlir/lib/IR/PatternMatch.cpp
vendored
@ -149,12 +149,13 @@ void PatternRewriter::updatedRootInPlace(
|
|||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
RewritePatternMatcher::RewritePatternMatcher(
|
RewritePatternMatcher::RewritePatternMatcher(
|
||||||
OwningRewritePatternList &&patterns)
|
OwningRewritePatternList &patterns) {
|
||||||
: patterns(std::move(patterns)) {
|
for (auto &pattern : patterns)
|
||||||
|
this->patterns.push_back(pattern.get());
|
||||||
|
|
||||||
// Sort the patterns by benefit to simplify the matching logic.
|
// Sort the patterns by benefit to simplify the matching logic.
|
||||||
std::stable_sort(this->patterns.begin(), this->patterns.end(),
|
std::stable_sort(this->patterns.begin(), this->patterns.end(),
|
||||||
[](const std::unique_ptr<RewritePattern> &l,
|
[](RewritePattern *l, RewritePattern *r) {
|
||||||
const std::unique_ptr<RewritePattern> &r) {
|
|
||||||
return r->getBenefit() < l->getBenefit();
|
return r->getBenefit() < l->getBenefit();
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
@ -162,7 +163,7 @@ RewritePatternMatcher::RewritePatternMatcher(
|
|||||||
/// Try to match the given operation to a pattern and rewrite it.
|
/// Try to match the given operation to a pattern and rewrite it.
|
||||||
bool RewritePatternMatcher::matchAndRewrite(Operation *op,
|
bool RewritePatternMatcher::matchAndRewrite(Operation *op,
|
||||||
PatternRewriter &rewriter) {
|
PatternRewriter &rewriter) {
|
||||||
for (auto &pattern : patterns) {
|
for (auto *pattern : patterns) {
|
||||||
// Ignore patterns that are for the wrong root or are impossible to match.
|
// Ignore patterns that are for the wrong root or are impossible to match.
|
||||||
if (pattern->getRootKind() != op->getName() ||
|
if (pattern->getRootKind() != op->getName() ||
|
||||||
pattern->getBenefit().isImpossibleToMatch())
|
pattern->getBenefit().isImpossibleToMatch())
|
||||||
|
@ -678,12 +678,11 @@ static void
|
|||||||
populateLinalgToLLVMConversionPatterns(LinalgTypeConverter &converter,
|
populateLinalgToLLVMConversionPatterns(LinalgTypeConverter &converter,
|
||||||
OwningRewritePatternList &patterns,
|
OwningRewritePatternList &patterns,
|
||||||
MLIRContext *ctx) {
|
MLIRContext *ctx) {
|
||||||
RewriteListBuilder<BufferAllocOpConversion, BufferDeallocOpConversion,
|
patterns.insert<BufferAllocOpConversion, BufferDeallocOpConversion,
|
||||||
BufferSizeOpConversion, DimOpConversion,
|
BufferSizeOpConversion, DimOpConversion,
|
||||||
LinalgOpConversion<DotOp>, LinalgOpConversion<MatmulOp>,
|
LinalgOpConversion<DotOp>, LinalgOpConversion<MatmulOp>,
|
||||||
LoadOpConversion, RangeOpConversion, SliceOpConversion,
|
LoadOpConversion, RangeOpConversion, SliceOpConversion,
|
||||||
StoreOpConversion, ViewOpConversion>::build(patterns, ctx,
|
StoreOpConversion, ViewOpConversion>(ctx, converter);
|
||||||
converter);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
@ -60,12 +60,9 @@ void RemoveInstrumentationPass::runOnFunction() {
|
|||||||
OwningRewritePatternList patterns;
|
OwningRewritePatternList patterns;
|
||||||
auto func = getFunction();
|
auto func = getFunction();
|
||||||
auto *context = &getContext();
|
auto *context = &getContext();
|
||||||
patterns.push_back(
|
patterns.insert<RemoveIdentityOpRewrite<StatisticsOp>,
|
||||||
llvm::make_unique<RemoveIdentityOpRewrite<StatisticsOp>>(context));
|
RemoveIdentityOpRewrite<StatisticsRefOp>,
|
||||||
patterns.push_back(
|
RemoveIdentityOpRewrite<CoupledRefOp>>(context);
|
||||||
llvm::make_unique<RemoveIdentityOpRewrite<StatisticsRefOp>>(context));
|
|
||||||
patterns.push_back(
|
|
||||||
llvm::make_unique<RemoveIdentityOpRewrite<CoupledRefOp>>(context));
|
|
||||||
applyPatternsGreedily(func, std::move(patterns));
|
applyPatternsGreedily(func, std::move(patterns));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
25
third_party/mlir/lib/StandardOps/Ops.cpp
vendored
25
third_party/mlir/lib/StandardOps/Ops.cpp
vendored
@ -365,8 +365,7 @@ struct SimplifyDeadAlloc : public OpRewritePattern<AllocOp> {
|
|||||||
|
|
||||||
void AllocOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
void AllocOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
||||||
MLIRContext *context) {
|
MLIRContext *context) {
|
||||||
RewriteListBuilder<SimplifyAllocConst, SimplifyDeadAlloc>::build(results,
|
results.insert<SimplifyAllocConst, SimplifyDeadAlloc>(context);
|
||||||
context);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
@ -544,8 +543,7 @@ static LogicalResult verify(CallIndirectOp op) {
|
|||||||
|
|
||||||
void CallIndirectOp::getCanonicalizationPatterns(
|
void CallIndirectOp::getCanonicalizationPatterns(
|
||||||
OwningRewritePatternList &results, MLIRContext *context) {
|
OwningRewritePatternList &results, MLIRContext *context) {
|
||||||
results.push_back(
|
results.insert<SimplifyIndirectCallWithKnownCallee>(context);
|
||||||
llvm::make_unique<SimplifyIndirectCallWithKnownCallee>(context));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
@ -1015,7 +1013,7 @@ static void print(OpAsmPrinter *p, CondBranchOp op) {
|
|||||||
|
|
||||||
void CondBranchOp::getCanonicalizationPatterns(
|
void CondBranchOp::getCanonicalizationPatterns(
|
||||||
OwningRewritePatternList &results, MLIRContext *context) {
|
OwningRewritePatternList &results, MLIRContext *context) {
|
||||||
results.push_back(llvm::make_unique<SimplifyConstCondBranchPred>(context));
|
results.insert<SimplifyConstCondBranchPred>(context);
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
@ -1231,9 +1229,8 @@ static LogicalResult verify(DeallocOp op) {
|
|||||||
void DeallocOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
void DeallocOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
||||||
MLIRContext *context) {
|
MLIRContext *context) {
|
||||||
/// dealloc(memrefcast) -> dealloc
|
/// dealloc(memrefcast) -> dealloc
|
||||||
results.push_back(
|
results.insert<MemRefCastFolder>(getOperationName(), context);
|
||||||
llvm::make_unique<MemRefCastFolder>(getOperationName(), context));
|
results.insert<SimplifyDeadDealloc>(context);
|
||||||
results.push_back(llvm::make_unique<SimplifyDeadDealloc>(context));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
@ -1497,8 +1494,7 @@ LogicalResult DmaStartOp::verify() {
|
|||||||
void DmaStartOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
void DmaStartOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
||||||
MLIRContext *context) {
|
MLIRContext *context) {
|
||||||
/// dma_start(memrefcast) -> dma_start
|
/// dma_start(memrefcast) -> dma_start
|
||||||
results.push_back(
|
results.insert<MemRefCastFolder>(getOperationName(), context);
|
||||||
llvm::make_unique<MemRefCastFolder>(getOperationName(), context));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// ---------------------------------------------------------------------------
|
// ---------------------------------------------------------------------------
|
||||||
@ -1561,8 +1557,7 @@ ParseResult DmaWaitOp::parse(OpAsmParser *parser, OperationState *result) {
|
|||||||
void DmaWaitOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
void DmaWaitOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
||||||
MLIRContext *context) {
|
MLIRContext *context) {
|
||||||
/// dma_wait(memrefcast) -> dma_wait
|
/// dma_wait(memrefcast) -> dma_wait
|
||||||
results.push_back(
|
results.insert<MemRefCastFolder>(getOperationName(), context);
|
||||||
llvm::make_unique<MemRefCastFolder>(getOperationName(), context));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
@ -1695,8 +1690,7 @@ static LogicalResult verify(LoadOp op) {
|
|||||||
void LoadOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
void LoadOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
||||||
MLIRContext *context) {
|
MLIRContext *context) {
|
||||||
/// load(memrefcast) -> load
|
/// load(memrefcast) -> load
|
||||||
results.push_back(
|
results.insert<MemRefCastFolder>(getOperationName(), context);
|
||||||
llvm::make_unique<MemRefCastFolder>(getOperationName(), context));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
@ -2007,8 +2001,7 @@ static LogicalResult verify(StoreOp op) {
|
|||||||
void StoreOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
void StoreOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
||||||
MLIRContext *context) {
|
MLIRContext *context) {
|
||||||
/// store(memrefcast) -> store
|
/// store(memrefcast) -> store
|
||||||
results.push_back(
|
results.insert<MemRefCastFolder>(getOperationName(), context);
|
||||||
llvm::make_unique<MemRefCastFolder>(getOperationName(), context));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -1243,8 +1243,7 @@ struct FuncOpSignatureConversion : public ConversionPattern {
|
|||||||
void mlir::populateFuncOpTypeConversionPattern(
|
void mlir::populateFuncOpTypeConversionPattern(
|
||||||
OwningRewritePatternList &patterns, MLIRContext *ctx,
|
OwningRewritePatternList &patterns, MLIRContext *ctx,
|
||||||
TypeConverter &converter) {
|
TypeConverter &converter) {
|
||||||
RewriteListBuilder<FuncOpSignatureConversion>::build(patterns, ctx,
|
patterns.insert<FuncOpSignatureConversion>(ctx, converter);
|
||||||
converter);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// This function converts the type signature of the given block, by invoking
|
/// This function converts the type signature of the given block, by invoking
|
||||||
|
@ -507,10 +507,11 @@ public:
|
|||||||
|
|
||||||
void mlir::populateAffineToStdConversionPatterns(
|
void mlir::populateAffineToStdConversionPatterns(
|
||||||
OwningRewritePatternList &patterns, MLIRContext *ctx) {
|
OwningRewritePatternList &patterns, MLIRContext *ctx) {
|
||||||
RewriteListBuilder<AffineApplyLowering, AffineDmaStartLowering,
|
patterns
|
||||||
AffineDmaWaitLowering, AffineLoadLowering,
|
.insert<AffineApplyLowering, AffineDmaStartLowering,
|
||||||
AffineStoreLowering, AffineForLowering, AffineIfLowering,
|
AffineDmaWaitLowering, AffineLoadLowering, AffineStoreLowering,
|
||||||
AffineTerminatorLowering>::build(patterns, ctx);
|
AffineForLowering, AffineIfLowering, AffineTerminatorLowering>(
|
||||||
|
ctx);
|
||||||
}
|
}
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
@ -365,12 +365,8 @@ struct LowerVectorTransfersPass
|
|||||||
void runOnFunction() {
|
void runOnFunction() {
|
||||||
OwningRewritePatternList patterns;
|
OwningRewritePatternList patterns;
|
||||||
auto *context = &getContext();
|
auto *context = &getContext();
|
||||||
patterns.push_back(
|
patterns.insert<VectorTransferRewriter<VectorTransferReadOp>,
|
||||||
llvm::make_unique<VectorTransferRewriter<VectorTransferReadOp>>(
|
VectorTransferRewriter<VectorTransferWriteOp>>(context);
|
||||||
context));
|
|
||||||
patterns.push_back(
|
|
||||||
llvm::make_unique<VectorTransferRewriter<VectorTransferWriteOp>>(
|
|
||||||
context));
|
|
||||||
applyPatternsGreedily(getFunction(), std::move(patterns));
|
applyPatternsGreedily(getFunction(), std::move(patterns));
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@ -44,8 +44,8 @@ namespace {
|
|||||||
class GreedyPatternRewriteDriver : public PatternRewriter {
|
class GreedyPatternRewriteDriver : public PatternRewriter {
|
||||||
public:
|
public:
|
||||||
explicit GreedyPatternRewriteDriver(MLIRContext *ctx,
|
explicit GreedyPatternRewriteDriver(MLIRContext *ctx,
|
||||||
OwningRewritePatternList &&patterns)
|
OwningRewritePatternList &patterns)
|
||||||
: PatternRewriter(ctx), matcher(std::move(patterns)) {
|
: PatternRewriter(ctx), matcher(patterns) {
|
||||||
worklist.reserve(64);
|
worklist.reserve(64);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -224,7 +224,7 @@ bool mlir::applyPatternsGreedily(Operation *op,
|
|||||||
if (!op->isKnownIsolatedFromAbove())
|
if (!op->isKnownIsolatedFromAbove())
|
||||||
return false;
|
return false;
|
||||||
|
|
||||||
GreedyPatternRewriteDriver driver(op->getContext(), std::move(patterns));
|
GreedyPatternRewriteDriver driver(op->getContext(), patterns);
|
||||||
bool converged = driver.simplify(op, maxPatternMatchIterations);
|
bool converged = driver.simplify(op, maxPatternMatchIterations);
|
||||||
LLVM_DEBUG(if (!converged) {
|
LLVM_DEBUG(if (!converged) {
|
||||||
llvm::dbgs() << "The pattern rewrite doesn't converge after scanning "
|
llvm::dbgs() << "The pattern rewrite doesn't converge after scanning "
|
||||||
|
@ -41,7 +41,7 @@ struct TestPatternDriver : public FunctionPass<TestPatternDriver> {
|
|||||||
populateWithGenerated(&getContext(), &patterns);
|
populateWithGenerated(&getContext(), &patterns);
|
||||||
|
|
||||||
// Verify named pattern is generated with expected name.
|
// Verify named pattern is generated with expected name.
|
||||||
RewriteListBuilder<TestNamedPatternRule>::build(patterns, &getContext());
|
patterns.insert<TestNamedPatternRule>(&getContext());
|
||||||
|
|
||||||
applyPatternsGreedily(getFunction(), std::move(patterns));
|
applyPatternsGreedily(getFunction(), std::move(patterns));
|
||||||
}
|
}
|
||||||
@ -193,9 +193,9 @@ struct TestLegalizePatternDriver
|
|||||||
TestTypeConverter converter;
|
TestTypeConverter converter;
|
||||||
mlir::OwningRewritePatternList patterns;
|
mlir::OwningRewritePatternList patterns;
|
||||||
populateWithGenerated(&getContext(), &patterns);
|
populateWithGenerated(&getContext(), &patterns);
|
||||||
RewriteListBuilder<TestRegionRewriteBlockMovement, TestRegionRewriteUndo,
|
patterns.insert<TestRegionRewriteBlockMovement, TestRegionRewriteUndo,
|
||||||
TestDropOp, TestPassthroughInvalidOp,
|
TestDropOp, TestPassthroughInvalidOp, TestSplitReturnType>(
|
||||||
TestSplitReturnType>::build(patterns, &getContext());
|
&getContext());
|
||||||
mlir::populateFuncOpTypeConversionPattern(patterns, &getContext(),
|
mlir::populateFuncOpTypeConversionPattern(patterns, &getContext(),
|
||||||
converter);
|
converter);
|
||||||
|
|
||||||
|
@ -133,7 +133,7 @@ static LogicalResult runMLIRPasses(ModuleOp m) {
|
|||||||
pm.addPass(createConvertToLLVMIRPass([](LLVMTypeConverter &converter,
|
pm.addPass(createConvertToLLVMIRPass([](LLVMTypeConverter &converter,
|
||||||
OwningRewritePatternList &patterns) {
|
OwningRewritePatternList &patterns) {
|
||||||
populateStdToLLVMConversionPatterns(converter, patterns);
|
populateStdToLLVMConversionPatterns(converter, patterns);
|
||||||
patterns.push_back(llvm::make_unique<GPULaunchFuncOpLowering>(converter));
|
patterns.insert<GPULaunchFuncOpLowering>(converter);
|
||||||
}));
|
}));
|
||||||
pm.addPass(createLowerGpuOpsToNVVMOpsPass());
|
pm.addPass(createLowerGpuOpsToNVVMOpsPass());
|
||||||
pm.addPass(createConvertGPUKernelToCubinPass(&compilePtxToCubin));
|
pm.addPass(createConvertGPUKernelToCubinPass(&compilePtxToCubin));
|
||||||
|
@ -935,8 +935,7 @@ static void emitRewriters(const RecordKeeper &recordKeeper, raw_ostream &os) {
|
|||||||
os << "void populateWithGenerated(MLIRContext *context, "
|
os << "void populateWithGenerated(MLIRContext *context, "
|
||||||
<< "OwningRewritePatternList *patterns) {\n";
|
<< "OwningRewritePatternList *patterns) {\n";
|
||||||
for (const auto &name : rewriterNames) {
|
for (const auto &name : rewriterNames) {
|
||||||
os << " patterns->push_back(llvm::make_unique<" << name
|
os << " patterns->insert<" << name << ">(context);\n";
|
||||||
<< ">(context));\n";
|
|
||||||
}
|
}
|
||||||
os << "}\n";
|
os << "}\n";
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user