diff --git a/third_party/mlir/include/mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h b/third_party/mlir/include/mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h index 65cb2e63dc0..354c94cdd68 100644 --- a/third_party/mlir/include/mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h +++ b/third_party/mlir/include/mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h @@ -22,6 +22,8 @@ #include "mlir/IR/OpDefinition.h" namespace mlir { +class FuncOp; + namespace linalg { class LinalgOp; @@ -71,6 +73,8 @@ public: enum DependenceType { RAR = 0, RAW, WAR, WAW, NumTypes }; + // Builds a linalg dependence graph for the ops of type LinalgOp under `f`. + static LinalgDependenceGraph buildDependenceGraph(Aliases &aliases, FuncOp f); LinalgDependenceGraph(Aliases &aliases, ArrayRef ops); /// Returns the X such that op -> X is a dependence of type dt. diff --git a/third_party/mlir/include/mlir/Dialect/Linalg/Passes.h b/third_party/mlir/include/mlir/Dialect/Linalg/Passes.h index f4eb79786e3..fb68c0ae9c3 100644 --- a/third_party/mlir/include/mlir/Dialect/Linalg/Passes.h +++ b/third_party/mlir/include/mlir/Dialect/Linalg/Passes.h @@ -42,7 +42,6 @@ std::unique_ptr> createLowerLinalgToLoopsPass(); std::unique_ptr> createLowerLinalgToLLVMPass(); -std::unique_ptr> createLinalgTransformsPass(); } // namespace linalg } // namespace mlir diff --git a/third_party/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransformPatterns.td b/third_party/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransformPatterns.td index aef7e76f859..9cc4ea3218e 100644 --- a/third_party/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransformPatterns.td +++ b/third_party/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransformPatterns.td @@ -33,16 +33,44 @@ class HasLinalgTransformMarker : CPred<[{ $0.getAttrOfType(kLinalgTransformMarker).getValue() == "}] # value # [{"}]>; +class IsProducedByOpOfType : + CPred<"isProducedByOpOfType<" # value # ">($0, $1)">; + //===----------------------------------------------------------------------===// -// Linalg transformation patterns. +// Linalg fusion patterns. //===----------------------------------------------------------------------===// +// +// In the future, tile sizes should be derived from op properties + machine +// model but we do not need to wait on this to start having useful patterns. +class TileAndFuseLinalgOp sizes, string value> : NativeCodeCall< + "if (failed(tileAndFuseLinalgOpAndSetMarker($_builder, $0, {" # + StrJoinInt.result # "}, \"" # value # "\")))" # + " return matchFailure();">; + +def : Pat<(MatmulOp:$consumer $A, $B, $C), + (TileAndFuseLinalgOp<[100, 150], "L1"> $consumer), + [ + (Constraint $consumer), + (Constraint> $consumer, $A), + ], + // In the buffer world there is no use-def chains or dags so benefits + // cannot be computed automatically from the length of the matched + // pattern. Instead we specify the benefit ourselves for now. + // This is not expected to be a big challenge long-term because + // pattern benefits are akin to feature engineering: features should + // be learned. + (addBenefit 1)>; + +//===----------------------------------------------------------------------===// +// Linalg tiling patterns. +//===----------------------------------------------------------------------===// +// +// In the future, tile sizes should be derived from op properties + machine +// model but we do not need to wait on this to start having useful patterns. class TileLinalgOp sizes, string value> : NativeCodeCall< - "auto res = tileLinalgOperation($_builder, $0, ArrayRef{" # - StrJoinInt.result # "});" # [{ - if (!res) - return matchFailure(); - res->op.setAttr(kLinalgTransformMarker, StringAttr::get("}] # value # - [{", $0.getContext()));}]>; + "if (failed(tileLinalgOpAndSetMarker($_builder, $0, {" # + StrJoinInt.result # "}, \"" # value # "\")))" # + " return matchFailure();">; def : Pat<(MatmulOp:$op $A, $B, $C), (TileLinalgOp<[2000, 3000, 4000], "L3"> $op), diff --git a/third_party/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/third_party/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h index 0bfcdea2007..0401d6987aa 100644 --- a/third_party/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h +++ b/third_party/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h @@ -81,8 +81,21 @@ struct FusionInfo { LinalgOp fusedProducer; }; -// Fuses producer into consumer if the producer is structurally feasible and the -// fusion would not violate dependencies. +/// Checks whether the specific `producer` is the last write to exactly the +/// whole `consumedView`. This checks structural dominance, that the dependence +/// is a RAW without any interleaved write to any piece of `consumedView`. +bool isProducerLastWriteOfView(const LinalgDependenceGraph &graph, + LinalgOp consumer, Value *consumedView, + LinalgOp producer); + +/// Checks whether fusing the specific `producer` of the `consumedView` is +/// feasible. This checks `producer` is the last write of `consumedView` and +/// that no interleaved dependence would be violated (RAW, WAR or WAW). +bool isFusableInto(const LinalgDependenceGraph &graph, LinalgOp consumer, + Value *consumedView, LinalgOp producer); + +/// Fuses producer into consumer if the producer is structurally feasible and +/// the fusion would not violate dependencies. /// When non-null, the optional pointer `folder` is used to call into the /// `createAndFold` builder method. If `folder` is null, the regular `create` /// method is called. diff --git a/third_party/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp b/third_party/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp index 3a90e61ed10..9e57b7bb9de 100644 --- a/third_party/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp +++ b/third_party/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp @@ -86,6 +86,13 @@ Value *Aliases::find(Value *v) { } } +LinalgDependenceGraph +LinalgDependenceGraph::buildDependenceGraph(Aliases &aliases, FuncOp f) { + SmallVector linalgOps; + f.walk([&](LinalgOp op) { linalgOps.push_back(op); }); + return LinalgDependenceGraph(aliases, linalgOps); +} + LinalgDependenceGraph::LinalgDependenceGraph(Aliases &aliases, ArrayRef ops) : aliases(aliases), linalgOps(ops.begin(), ops.end()) { diff --git a/third_party/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/third_party/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp index 8e7370af7a3..82699545b3f 100644 --- a/third_party/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp +++ b/third_party/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp @@ -201,19 +201,13 @@ static LinalgOp fuse(Value *producedView, LinalgOp producer, LinalgOp consumer, // Encode structural fusion safety preconditions. // Some of these will be lifted in the future with better analysis. -static bool isStructurallyFusableProducer(LinalgOp producer, Value *readView, +static bool isStructurallyFusableProducer(LinalgOp producer, + Value *consumedView, LinalgOp consumer) { if (producer.getNumOutputs() != 1) { LLVM_DEBUG(dbgs() << "\nNot structurally fusable (multi-output)"); return false; } - // Must be a subview or a slice to guarantee there are loops we can fuse into. - auto subView = dyn_cast_or_null(readView->getDefiningOp()); - auto slice = dyn_cast_or_null(readView->getDefiningOp()); - if (!subView && !slice) { - LLVM_DEBUG(dbgs() << "\nNot structurally fusable (not a subview or slice)"); - return false; - } // Only fuse when the producer block dominates. DominanceInfo dom(producer.getOperation()); if (!dom.dominates(producer.getOperation()->getBlock(), @@ -226,6 +220,41 @@ static bool isStructurallyFusableProducer(LinalgOp producer, Value *readView, return true; } +bool mlir::linalg::isProducerLastWriteOfView(const LinalgDependenceGraph &graph, + LinalgOp consumer, + Value *consumedView, + LinalgOp producer) { + // Make some simple structural checks that alleviate the need for more + // complex analyses. + if (!isStructurallyFusableProducer(producer, consumedView, consumer)) { + LLVM_DEBUG(dbgs() << "\n***Not static last write due to structure:\t" + << *producer.getOperation()); + return false; + } + // Check for any interleaved write to consumedView. + if (!graph.findCoveringWrites(producer, consumer, consumedView).empty()) { + LLVM_DEBUG(dbgs() << "\n***Not fusable due to interleaved write:\t" + << *producer.getOperation()); + return false; + } + return true; +} + +bool mlir::linalg::isFusableInto(const LinalgDependenceGraph &graph, + LinalgOp consumer, Value *consumedView, + LinalgOp producer) { + if (!isProducerLastWriteOfView(graph, consumer, consumedView, producer)) + return false; + // Check for any fusion-preventing dependence to any view read/written that + // would violate dependences. + if (!graph.findCoveringDependences(producer, consumer).empty()) { + LLVM_DEBUG(dbgs() << "\n***Not fusable due to an interleaved dependence:\t" + << *producer.getOperation()); + return false; + } + return true; +} + // Only consider RAW atm. Optional mlir::linalg::fuseProducerOf( OpBuilder &b, LinalgOp consumer, unsigned consumerIdx, @@ -239,8 +268,8 @@ Optional mlir::linalg::fuseProducerOf( auto producer = cast(dependence.dependentOpView.op); // Check that the dependence is indeed on the input `consumerIdx` view. - auto *readView = dependence.indexingView; - if (consumer.getInput(consumerIdx) != readView) + auto *consumedView = dependence.indexingView; + if (consumer.getInput(consumerIdx) != consumedView) continue; // Consumer consumes this view, `isStructurallyFusableProducer` also checks @@ -252,16 +281,17 @@ Optional mlir::linalg::fuseProducerOf( << " view: " << *producedView << " output index: " << producerIdx); - // Make some simple structural checks that alleviate the need for more - // complex analyses. - if (!isStructurallyFusableProducer(producer, readView, consumer)) { - LLVM_DEBUG(dbgs() << "\n***Not fusable:\t" << *producer.getOperation()); + // Must be a subview or a slice to guarantee there are loops we can fuse + // into. + auto subView = dyn_cast_or_null(consumedView->getDefiningOp()); + auto slice = dyn_cast_or_null(consumedView->getDefiningOp()); + if (!subView && !slice) { + LLVM_DEBUG(dbgs() << "\nNot fusable (not a subview or slice)"); continue; } - // Check for fusion-preventing write that would violate dependences. - // `view` is a producer write that cannot bypass any other write or read. - if (!graph.findCoveringDependences(producer, consumer).empty()) + // Simple fusability checks. + if (!isFusableInto(graph, consumer, consumedView, producer)) continue; // Fuse `producer` just before `consumer`. diff --git a/third_party/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp b/third_party/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp index 118018b9372..aaa7d9dabf6 100644 --- a/third_party/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp +++ b/third_party/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp @@ -19,6 +19,7 @@ // //===----------------------------------------------------------------------===// +#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h" #include "mlir/Dialect/Linalg/IR/LinalgOps.h" #include "mlir/Dialect/Linalg/Passes.h" #include "mlir/Dialect/Linalg/Utils/Utils.h" @@ -26,11 +27,81 @@ #include "mlir/Pass/Pass.h" using namespace mlir; -using mlir::linalg::LinalgOp; +using namespace mlir::linalg; // Marker used as attribute name in generated Linalg rewriting transformations. static constexpr auto kLinalgTransformMarker = "__internal_linalg_transform__"; +static LogicalResult tileLinalgOpAndSetMarker(PatternRewriter &rewriter, + Operation *op, + ArrayRef sizes, + StringRef linalgMarker) { + auto tileRes = tileLinalgOperation(rewriter, op, sizes); + if (!tileRes) + return failure(); + tileRes->op.setAttr(kLinalgTransformMarker, + rewriter.getStringAttr(linalgMarker)); + tileRes->op.getParentOfType().dump(); + return success(); +} + +static LogicalResult tileAndFuseLinalgOpAndSetMarker(PatternRewriter &rewriter, + Operation *op, + ArrayRef sizes, + StringRef linalgMarker) { + auto tileRes = tileLinalgOperation(rewriter, op, sizes); + if (!tileRes) + return failure(); + tileRes->op.setAttr(kLinalgTransformMarker, + rewriter.getStringAttr(linalgMarker)); + Aliases aliases; + auto G = LinalgDependenceGraph::buildDependenceGraph( + aliases, op->getParentOfType()); + auto fusionRes = fuseProducerOf(rewriter, tileRes->op, 0, G); + if (!fusionRes) { + // Linalg fusion requires tiled loops to even determine whether it is + // possible to fuse. As a consequence, the pattern may fail even though a + // tiled version of op has already been introduced. + // So we need to remove the tiled version ourselves in case of failure. + // Another possibility is to ensure the constraints on the pattern guarantee + // that fusion will occur and just assert here. + // As we develop more complex patterns we can choose what is best. + rewriter.eraseOp(tileRes->loops[0]); + return failure(); + } + fusionRes->fusedProducer.setAttr(kLinalgTransformMarker, + rewriter.getStringAttr(linalgMarker)); + // The originalProducer can now be safely erased. This is similar to SSA-value + // use-def but in the world of buffer + structured ops. + rewriter.eraseOp(fusionRes->originalProducer); + fusionRes->fusedProducer.getParentOfType().dump(); + return success(); +} + +template +bool isProducedByOpOfType(Operation *consumerOp, Value *consumedView) { + LinalgOp consumer = dyn_cast(consumerOp); + if (!consumer) + return false; + + auto maybeConsumerIndex = consumer.getIndexOfInput(consumedView); + if (!maybeConsumerIndex) + return false; + + Aliases aliases; + auto G = LinalgDependenceGraph::buildDependenceGraph( + aliases, consumer.getParentOfType()); + for (auto dependence : G.getDependencesInto( + consumer, LinalgDependenceGraph::DependenceType::RAW)) { + auto producer = cast(dependence.dependentOpView.op); + if (!isProducerLastWriteOfView(G, consumer, consumedView, producer)) + continue; + if (isa(dependence.dependentOpView.op)) + return true; + } + return false; +} + namespace mlir { namespace linalg { namespace { @@ -58,10 +129,6 @@ void LinalgTransforms::runOnFunction() { funcOp.walk([](LinalgOp op) { op.removeAttr(kLinalgTransformMarker); }); } -std::unique_ptr> mlir::linalg::createLinalgTransformsPass() { - return std::make_unique(); -} - static PassRegistration pass("test-linalg-transform-patterns", "Test Linalg transformation patterns by applying them greedily.");