diff --git a/third_party/mlir/BUILD b/third_party/mlir/BUILD
index 7ba88d877d3..b081ad194e5 100644
--- a/third_party/mlir/BUILD
+++ b/third_party/mlir/BUILD
@@ -1738,6 +1738,7 @@ cc_library(
         "include/mlir/Linalg/Utils/Utils.h",
     ],
     deps = [
+        ":AffineOps",
         ":CFGTransforms",
         ":EDSC",
         ":IR",
diff --git a/third_party/mlir/include/mlir/EDSC/Intrinsics.h b/third_party/mlir/include/mlir/EDSC/Intrinsics.h
index 021fec2f444..6870e029ce8 100644
--- a/third_party/mlir/include/mlir/EDSC/Intrinsics.h
+++ b/third_party/mlir/include/mlir/EDSC/Intrinsics.h
@@ -61,20 +61,32 @@ struct IndexHandle : public ValueHandle {
     this->v = v.getValue();
     return *this;
   }
-  static SmallVector<IndexHandle, 8> makeIndexHandles(unsigned rank) {
-    return SmallVector<IndexHandle, 8>(rank);
-  }
-  static SmallVector<ValueHandle *, 8>
-  makeIndexHandlePointers(SmallVectorImpl<IndexHandle> &ivs) {
-    SmallVector<ValueHandle *, 8> pivs;
-    pivs.reserve(ivs.size());
-    for (auto &iv : ivs) {
-      pivs.push_back(&iv);
-    }
-    return pivs;
-  }
 };
+
+inline SmallVector<IndexHandle, 8> makeIndexHandles(unsigned rank) {
+  return SmallVector<IndexHandle, 8>(rank);
+}
+
+inline SmallVector<ValueHandle *, 8>
+makeIndexHandlePointers(MutableArrayRef<IndexHandle> ivs) {
+  SmallVector<ValueHandle *, 8> pivs;
+  pivs.reserve(ivs.size());
+  for (auto &iv : ivs) {
+    pivs.push_back(&iv);
+  }
+  return pivs;
+}
+
+/// Returns a vector of the underlying Value* from `ivs`.
+inline SmallVector<Value *, 8> extractValues(ArrayRef<IndexHandle> ivs) {
+  SmallVector<Value *, 8> vals;
+  vals.reserve(ivs.size());
+  for (auto &iv : ivs) {
+    vals.push_back(iv.getValue());
+  }
+  return vals;
+}
+
 /// Provides a set of first class intrinsics.
 /// In the future, most of intrinsics related to Operation that don't contain
 /// other operations should be Tablegen'd.
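The three helpers above were static members of `IndexHandle` and are now free functions, so they can be reused outside the struct. A minimal usage sketch, assuming an active `edsc::ScopedContext` at the insertion point and the `LoopNestRangeBuilder` from the Linalg utilities further down in this patch; `buildLoopNest` itself is a hypothetical wrapper:

```
// Sketch only: `buildLoopNest` is a hypothetical helper. It assumes an
// active edsc::ScopedContext and the Linalg LoopNestRangeBuilder used below.
#include "mlir/EDSC/Intrinsics.h"
#include "mlir/Linalg/Utils/Utils.h"

using namespace mlir;
using namespace mlir::edsc;

static void buildLoopNest(ArrayRef<Value *> loopRanges) {
  // One default-constructed handle per loop; the nest builder binds them.
  SmallVector<IndexHandle, 8> ivs = makeIndexHandles(loopRanges.size());
  // Builders take ValueHandle* so they can write the bound IVs back.
  SmallVector<ValueHandle *, 8> pivs = makeIndexHandlePointers(ivs);
  LoopNestRangeBuilder(pivs, loopRanges)([&] {
    // Inside the nest, recover the raw Value* to use as indices.
    SmallVector<Value *, 8> ivValues = extractValues(ivs);
    (void)ivValues;
  });
}
```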
"mlir/Linalg/Utils/Intrinsics.h" #include "mlir/Support/LLVM.h" namespace mlir { @@ -79,7 +80,16 @@ namespace linalg { /// Returns the linearized list of all view dimensions in a linalgOp. Applying /// the inverse, concatenated loopToOperandRangeMaps to this list allows the /// derivation of loop ranges for any linalgOp. -SmallVector getViewSizes(LinalgOp &linalgOp); +template +SmallVector getViewSizes(ConcreteOp linalgOp) { + SmallVector res; + for (auto v : linalgOp.getInputsAndOutputs()) { + ViewType t = v->getType().template cast(); + for (unsigned i = 0; i < t.getRank(); ++i) + res.push_back(intrinsics::dim(v, i)); + } + return res; +} /// Returns the values obtained by applying `map` to the list of values. /// Performs simplifications and foldings where possible. diff --git a/third_party/mlir/lib/Linalg/CMakeLists.txt b/third_party/mlir/lib/Linalg/CMakeLists.txt index d015940e3c0..b37bdaac440 100644 --- a/third_party/mlir/lib/Linalg/CMakeLists.txt +++ b/third_party/mlir/lib/Linalg/CMakeLists.txt @@ -14,4 +14,11 @@ add_llvm_library(MLIRLinalg DEPENDS intrinsics_gen ) -add_dependencies(MLIRLinalg MLIRLinalgOpsIncGen MLIRLinalgLibraryOpsIncGen MLIRStandardToLLVM) + +add_dependencies(MLIRLinalg + + MLIRAffineOps + MLIRLinalgOpsIncGen + MLIRLinalgLibraryOpsIncGen + MLIRStandardToLLVM + ) diff --git a/third_party/mlir/lib/Linalg/IR/LinalgOps.cpp b/third_party/mlir/lib/Linalg/IR/LinalgOps.cpp index 59bddd302ec..f56470a6914 100644 --- a/third_party/mlir/lib/Linalg/IR/LinalgOps.cpp +++ b/third_party/mlir/lib/Linalg/IR/LinalgOps.cpp @@ -846,23 +846,6 @@ static SmallVector concat(ArrayRef a, return res; } -static SmallVector -foldedAffineApplies(OpBuilder &b, Location loc, AffineMap map, - ArrayRef vals, OperationFolder &folder) { - assert(map.getNumSymbols() == 0); - assert(map.getNumInputs() == vals.size()); - SmallVector res; - res.reserve(map.getNumResults()); - auto dims = map.getNumDims(); - for (auto e : map.getResults()) { - auto exprMap = AffineMap::get(dims, 0, e); - SmallVector operands(vals.begin(), vals.end()); - canonicalizeMapAndOperands(&exprMap, &operands); - res.push_back(affine_apply(folder, exprMap, operands)); - } - return res; -} - // Note: both functions below would completely disappear with a simple tensor // kernel language. // @@ -950,164 +933,3 @@ SmallVector mlir::linalg::loopToOperandRangesMaps(Operation *op) { } llvm_unreachable("Missing loopToOperandRangesMaps for op"); } - -static SmallVector permuteIvs(ArrayRef ivs, - Optional permutation, - OperationFolder &state) { - return permutation ? applyMapToValues(ScopedContext::getBuilder(), - ScopedContext::getLocation(), - permutation.getValue(), ivs, state) - : SmallVector(ivs.begin(), ivs.end()); -} - -// Ideally this should all be Tablegen'd but there is no good story for op -// expansion directly in MLIR for now. 
diff --git a/third_party/mlir/lib/Linalg/CMakeLists.txt b/third_party/mlir/lib/Linalg/CMakeLists.txt
index d015940e3c0..b37bdaac440 100644
--- a/third_party/mlir/lib/Linalg/CMakeLists.txt
+++ b/third_party/mlir/lib/Linalg/CMakeLists.txt
@@ -14,4 +14,11 @@ add_llvm_library(MLIRLinalg
   DEPENDS
   intrinsics_gen
   )
-add_dependencies(MLIRLinalg MLIRLinalgOpsIncGen MLIRLinalgLibraryOpsIncGen MLIRStandardToLLVM)
+
+add_dependencies(MLIRLinalg
+
+  MLIRAffineOps
+  MLIRLinalgOpsIncGen
+  MLIRLinalgLibraryOpsIncGen
+  MLIRStandardToLLVM
+  )
diff --git a/third_party/mlir/lib/Linalg/IR/LinalgOps.cpp b/third_party/mlir/lib/Linalg/IR/LinalgOps.cpp
index 59bddd302ec..f56470a6914 100644
--- a/third_party/mlir/lib/Linalg/IR/LinalgOps.cpp
+++ b/third_party/mlir/lib/Linalg/IR/LinalgOps.cpp
@@ -846,23 +846,6 @@ static SmallVector<Value *, 4> concat(ArrayRef<Value *> a,
   return res;
 }
 
-static SmallVector<Value *, 4>
-foldedAffineApplies(OpBuilder &b, Location loc, AffineMap map,
-                    ArrayRef<Value *> vals, OperationFolder &folder) {
-  assert(map.getNumSymbols() == 0);
-  assert(map.getNumInputs() == vals.size());
-  SmallVector<Value *, 4> res;
-  res.reserve(map.getNumResults());
-  auto dims = map.getNumDims();
-  for (auto e : map.getResults()) {
-    auto exprMap = AffineMap::get(dims, 0, e);
-    SmallVector<Value *, 4> operands(vals.begin(), vals.end());
-    canonicalizeMapAndOperands(&exprMap, &operands);
-    res.push_back(affine_apply(folder, exprMap, operands));
-  }
-  return res;
-}
-
 // Note: both functions below would completely disappear with a simple tensor
 // kernel language.
 //
@@ -950,164 +933,3 @@ SmallVector<AffineMap, 4> mlir::linalg::loopToOperandRangesMaps(Operation *op) {
   }
   llvm_unreachable("Missing loopToOperandRangesMaps for op");
 }
-
-static SmallVector<Value *, 4> permuteIvs(ArrayRef<Value *> ivs,
-                                          Optional<AffineMap> permutation,
-                                          OperationFolder &state) {
-  return permutation ? applyMapToValues(ScopedContext::getBuilder(),
-                                        ScopedContext::getLocation(),
-                                        permutation.getValue(), ivs, state)
-                     : SmallVector<Value *, 4>(ivs.begin(), ivs.end());
-}
-
-// Ideally this should all be Tablegen'd but there is no good story for op
-// expansion directly in MLIR for now.
-void mlir::linalg::emitScalarImplementation(
-    llvm::ArrayRef<Value *> parallelIvs, llvm::ArrayRef<Value *> reductionIvs,
-    llvm::ArrayRef<Value *> windowIvs, LinalgOp &linalgOp,
-    OperationFolder &folder) {
-  using linalg_load = ValueBuilder<linalg::LoadOp>;
-  using linalg_store = OperationBuilder<linalg::StoreOp>;
-  using IndexedValue = TemplatedIndexedValue<linalg_load, linalg_store>;
-  using edsc::op::operator+;
-  using edsc::op::operator*;
-  using edsc::op::operator==;
-  using edsc::intrinsics::select;
-
-  auto nPar = parallelIvs.size();
-  auto nRed = reductionIvs.size();
-  auto nWin = windowIvs.size();
-  SmallVector<Value *, 8> allIvs;
-  allIvs.reserve(nPar + nRed + nWin);
-  allIvs.assign(parallelIvs.begin(), parallelIvs.end());
-  allIvs.append(reductionIvs.begin(), reductionIvs.end());
-  allIvs.append(windowIvs.begin(), windowIvs.end());
-
-  // Default OpBuilder supports 0-D case (no loops).
-  OpBuilder b(linalgOp.getOperation());
-  auto nLoops = nPar + nRed + nWin;
-  if (nLoops > 0) {
-    auto innermostLoop = loop::getForInductionVarOwner(allIvs.back());
-    // accounts for linalg.terminator in loop.
-    b = innermostLoop.getBodyBuilder();
-  }
-
-  auto loc = linalgOp.getLoc();
-  ScopedContext scope(b, loc);
-  auto *op = linalgOp.getOperation();
-  if (auto copyOp = dyn_cast<CopyOp>(op)) {
-    OperationFolder state;
-    auto inputIvs = permuteIvs(parallelIvs, copyOp.inputPermutation(), state);
-    auto outputIvs = permuteIvs(parallelIvs, copyOp.outputPermutation(), state);
-    SmallVector<IndexHandle, 8> iivs(inputIvs.begin(), inputIvs.end());
-    SmallVector<IndexHandle, 8> oivs(outputIvs.begin(), outputIvs.end());
-    // clang-format off
-    IndexedValue O(copyOp.getOutput(0)), I(copyOp.getInput(0));
-    nLoops > 0 ?
-      O(oivs) = I(iivs) :
-      O() = I();
-    // clang-format on
-    return;
-  }
-  if (auto fillOp = dyn_cast<FillOp>(op)) {
-    SmallVector<IndexHandle, 8> ivs(parallelIvs.begin(), parallelIvs.end());
-    // clang-format off
-    IndexedValue O(fillOp.getOutput(0));
-    nLoops > 0 ?
-      O(ivs) = ValueHandle(fillOp.getValue()) :
-      O() = ValueHandle(fillOp.getValue());
-    // clang-format on
-    return;
-  }
-  if (auto dotOp = dyn_cast<DotOp>(op)) {
-    IndexHandle r_i(reductionIvs[0]);
-    IndexedValue A(dotOp.getInput(0)), B(dotOp.getInput(1)),
-        C(dotOp.getOutput(0));
-    C() = C() + A(r_i) * B(r_i);
-    return;
-  }
-  if (auto matvecOp = dyn_cast<MatvecOp>(op)) {
-    IndexHandle i(parallelIvs[0]), r_j(reductionIvs[0]);
-    IndexedValue A(matvecOp.getInput(0)), B(matvecOp.getInput(1)),
-        C(matvecOp.getOutput(0));
-    C(i) = C(i) + A(i, r_j) * B(r_j);
-    return;
-  }
-  if (auto matmulOp = dyn_cast<MatmulOp>(op)) {
-    IndexHandle i(parallelIvs[0]), j(parallelIvs[1]), r_k(reductionIvs[0]);
-    IndexedValue A(matmulOp.getInput(0)), B(matmulOp.getInput(1)),
-        C(matmulOp.getOutput(0));
-    C(i, j) = C(i, j) + A(i, r_k) * B(r_k, j);
-    return;
-  }
-  if (auto convOp = dyn_cast<ConvOp>(op)) {
-    auto maps = loopToOperandRangesMaps(op);
-    SmallVector<ValueHandle, 8> fIdx(
-        foldedAffineApplies(b, loc, maps[0], allIvs, folder));
-    SmallVector<ValueHandle, 8> imIdx(
-        foldedAffineApplies(b, loc, maps[1], allIvs, folder));
-    SmallVector<ValueHandle, 8> oIdx(
-        foldedAffineApplies(b, loc, maps[2], allIvs, folder));
-    IndexedValue F(convOp.filter()), I(convOp.input()), O(convOp.output());
-    O(oIdx) += F(fIdx) * I(imIdx);
-    return;
-  }
-  if (auto genericOp = dyn_cast<GenericOp>(op)) {
-    using edsc::intrinsics::detail::ValueHandleArray;
-    unsigned nInputs = genericOp.getNumInputs();
-    unsigned nOutputs = genericOp.getNumOutputs();
-    SmallVector<Value *, 4> indexedValues(nInputs + nOutputs);
-    // Emits the MLIR for the scalar part of the generic op by:
-    //   1. Emitting linalg_load and linalg_store ops for each input and output
-    //      view in order. This is achieved by applying the appropriate input or
-    //      output map to the enclosing induction variables.
-    //   2. Emitting a call to `op.fun()` that takes as arguments the scalars
-    //      from point 1. above.
-    //   3. Emitting linalg_store to store the results of 2. to the output
-    //      views.
-    //
-    // An example output may resemble:
-    //
-    // ```
-    //    loop.for %i = %c0 to %0 step %c1 {
-    //      loop.for %j = %c0 to %1 step %c1 {
-    //        loop.for %k = %c0 to %4 step %c1 {
-    //          %11 = linalg.load %arg0[%i, %j] : !linalg.view<?x?xf32>
-    //          %12 = linalg.load %arg1[%i, %j, %k] : !linalg.view<?x?x?xf32>
-    //          %13 = linalg.load %arg2[%i, %k, %j] : !linalg.view<?x?x?xf32>
-    //          %14:2 = call @foo(%11, %12, %13) : (f32, f32, f32) -> (f32, f32)
-    //          linalg.store %14#0, %arg1[%i, %j, %k] : !linalg.view<?x?x?xf32>
-    //          linalg.store %14#1, %arg2[%i, %k, %j] : !linalg.view<?x?x?xf32>
-    //        }
-    //      }
-    //    }
-    // ```
-
-    // 1.a. Emit linalg_load from input views.
-    for (unsigned i = 0, e = nInputs; i < e; ++i) {
-      ValueHandleArray indexing(foldedAffineApplies(
-          b, loc, genericOp.getInputIndexingMap(i), allIvs, folder));
-      indexedValues[i] = linalg_load(genericOp.getInput(i), indexing);
-    }
-    // 1.b. Emit linalg_load from output views..
-    for (unsigned i = 0, e = nOutputs; i < e; ++i) {
-      ValueHandleArray indexing(foldedAffineApplies(
-          b, loc, genericOp.getOutputIndexingMap(i), allIvs, folder));
-      indexedValues[nInputs + i] =
-          linalg_load(genericOp.getOutput(i), indexing);
-    }
-    // 2. Emit call.
-    auto m = genericOp.getParentOfType<ModuleOp>();
-    auto fun = m.lookupSymbol<FuncOp>(genericOp.fun());
-    Operation *callOp = call(fun, indexedValues);
-    assert(callOp->getNumResults() == genericOp.getNumOutputs());
-    // 3. Emit linalg_store.
-    for (unsigned i = 0, e = nOutputs; i < e; ++i) {
-      ValueHandleArray indexing(foldedAffineApplies(
-          b, loc, genericOp.getOutputIndexingMap(i), allIvs, folder));
-      linalg_store(callOp->getResult(i), genericOp.getOutput(i), indexing);
-    }
-    return;
-  }
-  llvm_unreachable("Missing emitScalarImplementation for op");
-}
diff --git a/third_party/mlir/lib/Linalg/Transforms/LowerToLoops.cpp b/third_party/mlir/lib/Linalg/Transforms/LowerToLoops.cpp
index 2e616c35f1d..c75ee48aac1 100644
--- a/third_party/mlir/lib/Linalg/Transforms/LowerToLoops.cpp
+++ b/third_party/mlir/lib/Linalg/Transforms/LowerToLoops.cpp
@@ -15,6 +15,8 @@
 // limitations under the License.
 // =============================================================================
 
+#include "mlir/AffineOps/AffineOps.h"
+#include "mlir/Dialect/LoopOps/LoopOps.h"
 #include "mlir/EDSC/Helpers.h"
 #include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/AffineMap.h"
@@ -22,17 +24,50 @@
 #include "mlir/Linalg/IR/LinalgOps.h"
 #include "mlir/Linalg/IR/LinalgTypes.h"
 #include "mlir/Linalg/Passes.h"
+#include "mlir/Linalg/Utils/Intrinsics.h"
 #include "mlir/Linalg/Utils/Utils.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/StandardOps/Ops.h"
 #include "mlir/Support/LLVM.h"
 #include "mlir/Support/STLExtras.h"
+#include "mlir/Transforms/DialectConversion.h"
 #include "mlir/Transforms/FoldUtils.h"
 
 using namespace mlir;
 using namespace mlir::edsc;
 using namespace mlir::edsc::intrinsics;
 using namespace mlir::linalg;
+using namespace mlir::linalg::intrinsics;
+
+using IndexedLinalgValue = TemplatedIndexedValue<linalg_load, linalg_store>;
+using edsc::op::operator+;
+using edsc::op::operator==;
+
+static SmallVector<Value *, 4>
+foldedAffineApplies(OpBuilder &b, Location loc, AffineMap map,
+                    ArrayRef<Value *> vals, OperationFolder &folder) {
+  assert(map.getNumSymbols() == 0);
+  assert(map.getNumInputs() == vals.size());
+  SmallVector<Value *, 4> res;
+  res.reserve(map.getNumResults());
+  auto dims = map.getNumDims();
+  for (auto e : map.getResults()) {
+    auto exprMap = AffineMap::get(dims, 0, e);
+    SmallVector<Value *, 4> operands(vals.begin(), vals.end());
+    canonicalizeMapAndOperands(&exprMap, &operands);
+    res.push_back(affine_apply(folder, exprMap, operands));
+  }
+  return res;
+}
+
+static SmallVector<Value *, 4> permuteIvs(ArrayRef<Value *> ivs,
+                                          Optional<AffineMap> permutation,
+                                          OperationFolder &state) {
+  return permutation ? applyMapToValues(ScopedContext::getBuilder(),
+                                        ScopedContext::getLocation(),
+                                        permutation.getValue(), ivs, state)
+                     : SmallVector<Value *, 4>(ivs.begin(), ivs.end());
+}
 
 // Creates a number of ranges equal to the number of results in `map`.
 // The returned ranges correspond to the loop ranges, in the proper order, for
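`foldedAffineApplies` (moved here from LinalgOps.cpp) expands a multi-result map into one folded `affine.apply` per result, canonicalizing away unused dims first; e.g. for an access map `(d0, d1, d2) -> (d0 + d2, d1)` over ivs `{%i, %j, %k}` it yields one value for `%i + %k` and one for `%j`. A standalone sketch of the same per-result expansion, assuming plain `OpBuilder` emission instead of the `OperationFolder`-backed `affine_apply` intrinsic used above; `perResultApplies` is a hypothetical name:

```
// Sketch only: mirrors foldedAffineApplies minus the folding, using the
// core AffineApplyOp builder.
#include "mlir/AffineOps/AffineOps.h"
#include "mlir/IR/Builders.h"

using namespace mlir;

static SmallVector<Value *, 4> perResultApplies(OpBuilder &b, Location loc,
                                                AffineMap map,
                                                ArrayRef<Value *> ivs) {
  SmallVector<Value *, 4> res;
  res.reserve(map.getNumResults());
  for (auto expr : map.getResults()) {
    // Single-result map for this expression, e.g. (d0, d1, d2) -> (d0 + d2).
    auto exprMap = AffineMap::get(map.getNumDims(), 0, expr);
    SmallVector<Value *, 4> operands(ivs.begin(), ivs.end());
    // Drops operands whose dims are unused by this particular result.
    canonicalizeMapAndOperands(&exprMap, &operands);
    res.push_back(b.create<AffineApplyOp>(loc, exprMap, operands));
  }
  return res;
}
```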
@@ -40,61 +75,272 @@ using namespace mlir::linalg;
 static SmallVector<Value *, 4> emitLoopRanges(OpBuilder &b, Location loc,
                                               AffineMap map,
                                               ArrayRef<Value *> allViewSizes,
-                                              OperationFolder &state) {
+                                              OperationFolder &folder) {
   // Apply `map` to get view sizes in loop order.
-  auto sizes = applyMapToValues(b, loc, map, allViewSizes, state);
+  auto sizes = applyMapToValues(b, loc, map, allViewSizes, folder);
   // Create a new range with the applied tile sizes.
+  ScopedContext scope(b, loc);
   SmallVector<Value *, 4> res;
   for (unsigned idx = 0, e = map.getNumResults(); idx < e; ++idx) {
-    res.push_back(b.create<RangeOp>(
-        loc, state.create<ConstantIndexOp>(b, loc, 0), sizes[idx],
-        state.create<ConstantIndexOp>(b, loc, 1)));
+    res.push_back(range(constant_index(folder, 0), sizes[idx],
+                        constant_index(folder, 1)));
   }
   return res;
 }
 
-static void emitLinalgOpAsLoops(LinalgOp &linalgOp, OperationFolder &state) {
-  OpBuilder b(linalgOp.getOperation());
-  ScopedContext scope(b, linalgOp.getOperation()->getLoc());
-  // The flattened loopToOperandRangesMaps is expected to be an invertible
-  // permutation map (which is asserted in the inverse calculation).
-  auto invertedMap =
-      inversePermutation(concatAffineMaps(loopToOperandRangesMaps(linalgOp)));
-  if (!invertedMap) {
-    mlir::linalg::emitScalarImplementation({}, {}, {}, linalgOp, state);
-    return;
+template <typename LinalgOpType> class LinalgScopedEmitter {};
+
+template <> class LinalgScopedEmitter<CopyOp> {
+public:
+  static void emitScalarImplementation(ArrayRef<Value *> allIvs, CopyOp copyOp,
+                                       OperationFolder &folder) {
+    auto nPar = copyOp.getNumParallelLoops();
+    assert(nPar == allIvs.size());
+    auto inputIvs =
+        permuteIvs(allIvs.take_front(nPar), copyOp.inputPermutation(), folder);
+    auto outputIvs =
+        permuteIvs(allIvs.take_front(nPar), copyOp.outputPermutation(), folder);
+    SmallVector<IndexHandle, 8> iivs(inputIvs.begin(), inputIvs.end());
+    SmallVector<IndexHandle, 8> oivs(outputIvs.begin(), outputIvs.end());
+    IndexedLinalgValue O(copyOp.getOutput(0)), I(copyOp.getInput(0));
+    // Emit the proper scalar assignment, whether we are dealing with a 0-D or
+    // an n-D loop nest; with or without permutations.
+    // clang-format off
+    nPar > 0 ? O(oivs) = I(iivs) :
+               O() = I();
+    // clang-format on
+  }
+};
+
+template <> class LinalgScopedEmitter<FillOp> {
+public:
+  static void emitScalarImplementation(ArrayRef<Value *> allIvs, FillOp fillOp,
+                                       OperationFolder &folder) {
+    auto nPar = fillOp.getNumParallelLoops();
+    assert(nPar == allIvs.size());
+    auto ivs =
+        SmallVector<IndexHandle, 4>(allIvs.begin(), allIvs.begin() + nPar);
+    IndexedLinalgValue O(fillOp.getOutput(0));
+    // Emit the proper scalar assignment, whether we are dealing with a 0-D or
+    // an n-D loop nest; with or without permutations.
+    nPar > 0 ? O(ivs) = ValueHandle(fillOp.getValue())
+             : O() = ValueHandle(fillOp.getValue());
+  }
+};
+
+template <> class LinalgScopedEmitter<DotOp> {
+public:
+  static void emitScalarImplementation(ArrayRef<Value *> allIvs, DotOp dotOp,
+                                       OperationFolder &folder) {
+    assert(allIvs.size() == 1);
+    IndexHandle r_i(allIvs[0]);
+    IndexedLinalgValue A(dotOp.getInput(0)), B(dotOp.getInput(1)),
+        C(dotOp.getOutput(0));
+    // Emit scalar form.
+    C() = C() + A(r_i) * B(r_i);
+  }
+};
+
+template <> class LinalgScopedEmitter<MatvecOp> {
+public:
+  static void emitScalarImplementation(ArrayRef<Value *> allIvs,
+                                       MatvecOp matvecOp,
+                                       OperationFolder &folder) {
+    assert(allIvs.size() == 2);
+    IndexHandle i(allIvs[0]), r_j(allIvs[1]);
+    IndexedLinalgValue A(matvecOp.getInput(0)), B(matvecOp.getInput(1)),
+        C(matvecOp.getOutput(0));
+    // Emit scalar form.
+    C(i) = C(i) + A(i, r_j) * B(r_j);
+  }
+};
+
+template <> class LinalgScopedEmitter<MatmulOp> {
+public:
+  static void emitScalarImplementation(ArrayRef<Value *> allIvs,
+                                       MatmulOp matmulOp,
+                                       OperationFolder &folder) {
+    assert(allIvs.size() == 3);
+    IndexHandle i(allIvs[0]), j(allIvs[1]), r_k(allIvs[2]);
+    IndexedLinalgValue A(matmulOp.getInput(0)), B(matmulOp.getInput(1)),
+        C(matmulOp.getOutput(0));
+    // Emit scalar form.
+    C(i, j) = C(i, j) + A(i, r_k) * B(r_k, j);
+  }
+};
+
+template <> class LinalgScopedEmitter<ConvOp> {
+public:
+  static void emitScalarImplementation(ArrayRef<Value *> allIvs, ConvOp convOp,
+                                       OperationFolder &folder) {
+    auto b = ScopedContext::getBuilder();
+    auto loc = ScopedContext::getLocation();
+    auto maps = loopToOperandRangesMaps(convOp);
+    SmallVector<ValueHandle, 8> fIdx(
+        foldedAffineApplies(b, loc, maps[0], allIvs, folder));
+    SmallVector<ValueHandle, 8> imIdx(
+        foldedAffineApplies(b, loc, maps[1], allIvs, folder));
+    SmallVector<ValueHandle, 8> oIdx(
+        foldedAffineApplies(b, loc, maps[2], allIvs, folder));
+    IndexedLinalgValue F(convOp.filter()), I(convOp.input()),
+        O(convOp.output());
+    // Emit scalar form.
+    O(oIdx) += F(fIdx) * I(imIdx);
+  }
+};
+
+// Emits the MLIR for the scalar part of the generic op by:
+//   1. Emitting linalg_load and linalg_store ops for each input and output
+//      view in order. This is achieved by applying the appropriate input or
+//      output map to the enclosing induction variables.
+//   2. Emitting a call to `op.fun()` that takes as arguments the scalars
+//      from point 1. above.
+//   3. Emitting linalg_store to store the results of 2. to the output
+//      views.
+//
+// An example output may resemble:
+//
+// ```
+//    loop.for %i = %c0 to %0 step %c1 {
+//      loop.for %j = %c0 to %1 step %c1 {
+//        loop.for %k = %c0 to %4 step %c1 {
+//          %11 = linalg.load %arg0[%i, %j] : !linalg.view<?x?xf32>
+//          %12 = linalg.load %arg1[%i, %j, %k] : !linalg.view<?x?x?xf32>
+//          %13 = linalg.load %arg2[%i, %k, %j] : !linalg.view<?x?x?xf32>
+//          %14:2 = call @foo(%11, %12, %13) : (f32, f32, f32) -> (f32, f32)
+//          linalg.store %14#0, %arg1[%i, %j, %k] : !linalg.view<?x?x?xf32>
+//          linalg.store %14#1, %arg2[%i, %k, %j] : !linalg.view<?x?x?xf32>
+//        }
+//      }
+//    }
+// ```
+template <> class LinalgScopedEmitter<GenericOp> {
+public:
+  static void emitScalarImplementation(ArrayRef<Value *> allIvs,
+                                       GenericOp genericOp,
+                                       OperationFolder &folder) {
+    auto b = ScopedContext::getBuilder();
+    auto loc = ScopedContext::getLocation();
+    using edsc::intrinsics::detail::ValueHandleArray;
+    unsigned nInputs = genericOp.getNumInputs();
+    unsigned nOutputs = genericOp.getNumOutputs();
+    SmallVector<Value *, 4> indexedValues(nInputs + nOutputs);
+
+    // 1.a. Emit linalg_load from input views.
+    for (unsigned i = 0, e = nInputs; i < e; ++i) {
+      ValueHandleArray indexing(foldedAffineApplies(
+          b, loc, genericOp.getInputIndexingMap(i), allIvs, folder));
+      indexedValues[i] = linalg_load(genericOp.getInput(i), indexing);
+    }
+
+    // 1.b. Emit linalg_load from output views.
+    for (unsigned i = 0, e = nOutputs; i < e; ++i) {
+      ValueHandleArray indexing(foldedAffineApplies(
+          b, loc, genericOp.getOutputIndexingMap(i), allIvs, folder));
+      indexedValues[nInputs + i] =
+          linalg_load(genericOp.getOutput(i), indexing);
+    }
+
+    // 2. Emit call.
+    auto m = genericOp.getParentOfType<ModuleOp>();
+    auto fun = m.lookupSymbol<FuncOp>(genericOp.fun());
+    Operation *callOp = call(fun, indexedValues);
+    assert(callOp->getNumResults() == genericOp.getNumOutputs());
+
+    // 3. Emit linalg_store.
+    for (unsigned i = 0, e = nOutputs; i < e; ++i) {
+      ValueHandleArray indexing(foldedAffineApplies(
+          b, loc, genericOp.getOutputIndexingMap(i), allIvs, folder));
+      linalg_store(callOp->getResult(i), genericOp.getOutput(i), indexing);
+    }
+  }
+};
+
+template <typename ConcreteOp>
+class LinalgRewritePattern : public RewritePattern {
+public:
+  explicit LinalgRewritePattern(MLIRContext *context)
+      : RewritePattern(ConcreteOp::getOperationName(), /*benefit=*/1, context) {
+  }
-
-  auto loopRanges = emitLoopRanges(scope.getBuilder(), scope.getLocation(),
-                                   invertedMap, getViewSizes(linalgOp), state);
+
+  PatternMatchResult matchAndRewrite(Operation *op,
+                                     PatternRewriter &rewriter) const override {
+    OpBuilder b(op);
+    ScopedContext scope(b, op->getLoc());
-  SmallVector<IndexHandle, 4> parallelIvs(linalgOp.getNumParallelLoops());
-  SmallVector<IndexHandle, 4> reductionIvs(linalgOp.getNumReductionLoops());
-  SmallVector<IndexHandle, 4> windowIvs(linalgOp.getNumWindowLoops());
-  auto pivs = IndexHandle::makeIndexHandlePointers(parallelIvs);
-  auto rivs = IndexHandle::makeIndexHandlePointers(reductionIvs);
-  auto wivs = IndexHandle::makeIndexHandlePointers(windowIvs);
-  assert(loopRanges.size() == pivs.size() + rivs.size() + wivs.size());
+    // The flattened loopToOperandRangesMaps is expected to be an invertible
+    // permutation map (which is asserted in the inverse calculation).
+    auto linalgOp = cast<ConcreteOp>(op);
+    auto invertedMap =
+        inversePermutation(concatAffineMaps(loopToOperandRangesMaps(linalgOp)));
+    if (!invertedMap) {
+      LinalgScopedEmitter<ConcreteOp>::emitScalarImplementation({}, linalgOp,
+                                                                folder);
+      rewriter.replaceOp(op, {});
+      return matchSuccess();
+    }
+
+    auto nPar = linalgOp.getNumParallelLoops();
+    auto nRed = linalgOp.getNumReductionLoops();
+    auto nWin = linalgOp.getNumWindowLoops();
+    SmallVector<IndexHandle, 4> allIvs(nPar + nRed + nWin);
+    SmallVector<ValueHandle *, 4> allPIvs = makeIndexHandlePointers(allIvs);
+    auto pivs = MutableArrayRef<ValueHandle *>(allPIvs).take_front(nPar);
+    auto rivs = MutableArrayRef<ValueHandle *>(allPIvs)
+                    .take_front(nPar + nRed)
+                    .take_back(nRed);
+    auto wivs = MutableArrayRef<ValueHandle *>(allPIvs).take_back(nWin);
+
+    auto loopRanges =
+        emitLoopRanges(scope.getBuilder(), scope.getLocation(), invertedMap,
+                       getViewSizes(linalgOp), folder);
+    assert(loopRanges.size() == pivs.size() + rivs.size() + wivs.size());
+
-  // clang-format off
-  ArrayRef<Value *> ranges(loopRanges);
-  LoopNestRangeBuilder(pivs, ranges.take_front(pivs.size()))([&] {
-    LoopNestRangeBuilder(
-        rivs, ranges.drop_back(wivs.size()).take_back(rivs.size()))([&] {
-      LoopNestRangeBuilder(wivs, ranges.take_back(wivs.size()))(
-          [&linalgOp, &parallelIvs, &reductionIvs, &windowIvs, &state] {
-            SmallVector<Value *, 4> parallel(
-                parallelIvs.begin(), parallelIvs.end());
-            SmallVector<Value *, 4> reduction(
-                reductionIvs.begin(), reductionIvs.end());
-            SmallVector<Value *, 4> window(
-                windowIvs.begin(), windowIvs.end());
-            mlir::linalg::emitScalarImplementation(
-                parallel, reduction, window, linalgOp, state);
+    // clang-format off
+    ArrayRef<Value *> ranges(loopRanges);
+    LoopNestRangeBuilder(pivs, ranges.take_front(nPar))([&] {
+      LoopNestRangeBuilder(rivs, ranges.drop_back(nWin).take_back(nRed))([&] {
+        LoopNestRangeBuilder(wivs, ranges.take_back(wivs.size()))(
+            [&linalgOp, &allIvs, this] {
+              auto allIvValues = extractValues(allIvs);
+              LinalgScopedEmitter<ConcreteOp>::emitScalarImplementation(
+                  allIvValues, linalgOp, folder);
             });
         });
       });
-  });
-  // clang-format on
+    // clang-format on
+    rewriter.replaceOp(op, {});
+    return matchSuccess();
+  }
+
+  mutable OperationFolder folder;
+};
+
+// Helper classes for type list expansion.
+template <typename... LinalgOps> class ConversionList;
+
+template <> class ConversionList<> {
+public:
+  static void build(OwningRewritePatternList &patterns, MLIRContext *ctx) {}
+};
+
+template <typename ConcreteOp, typename... LinalgOps>
+class ConversionList<ConcreteOp, LinalgOps...> {
+public:
+  static void build(OwningRewritePatternList &patterns, MLIRContext *ctx) {
+    patterns.insert<LinalgRewritePattern<ConcreteOp>>(ctx);
+    ConversionList<LinalgOps...>::build(patterns, ctx);
+  }
+};
+
+/// Populate the given list with patterns that convert from Linalg to loops.
+static void
+populateLinalgToLoopRewritePatterns(OwningRewritePatternList &patterns,
+                                    MLIRContext *ctx) {
+  ConversionList<
+#define GET_OP_LIST
+#include "mlir/Linalg/IR/LinalgLibraryOps.cpp.inc"
+      >::build(patterns, ctx);
+}
 
 namespace {
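The `ConversionList` recursion above peels one op type off the parameter pack per step, registering a `LinalgRewritePattern<Op>` for each entry of the TableGen-generated `GET_OP_LIST`. Spelled out for a hypothetical three-op list, the expansion is equivalent to:

```
// Sketch only: manual expansion of
// ConversionList<CopyOp, FillOp, MatmulOp>::build(patterns, ctx);
// the real op list comes from LinalgLibraryOps.cpp.inc.
static void buildExpanded(OwningRewritePatternList &patterns,
                          MLIRContext *ctx) {
  patterns.insert<LinalgRewritePattern<CopyOp>>(ctx);
  patterns.insert<LinalgRewritePattern<FillOp>>(ctx);
  patterns.insert<LinalgRewritePattern<MatmulOp>>(ctx);
}
```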
@@ -104,11 +350,17 @@ struct LowerLinalgToLoopsPass : public FunctionPass<LowerLinalgToLoopsPass> {
   void runOnFunction() override;
 };
 } // namespace
 
 void LowerLinalgToLoopsPass::runOnFunction() {
-  OperationFolder state;
-  getFunction().walk([&state](LinalgOp linalgOp) {
-    emitLinalgOpAsLoops(linalgOp, state);
-    linalgOp.getOperation()->erase();
-  });
+  OwningRewritePatternList patterns;
+  populateLinalgToLoopRewritePatterns(patterns, &getContext());
+
+  ConversionTarget target(getContext());
+  target.addLegalDialect<AffineOpsDialect>();
+  target.addLegalDialect<loop::LoopOpsDialect>();
+  target.addLegalDialect<StandardOpsDialect>();
+  if (failed(
+          applyPartialConversion(getFunction(), target, std::move(patterns)))) {
+    signalPassFailure();
+  }
 }
 
 FunctionPassBase *mlir::linalg::createLowerLinalgToLoopsPass() {
diff --git a/third_party/mlir/lib/Linalg/Transforms/Tiling.cpp b/third_party/mlir/lib/Linalg/Transforms/Tiling.cpp
index 25ffdebc61a..8090a587d42 100644
--- a/third_party/mlir/lib/Linalg/Transforms/Tiling.cpp
+++ b/third_party/mlir/lib/Linalg/Transforms/Tiling.cpp
@@ -381,7 +381,7 @@ mlir::linalg::tileLinalgOp(LinalgOp op, ArrayRef<Value *> tileSizes,
   // 3. Create the tiled loops.
   LinalgOp res = op;
   SmallVector<IndexHandle, 4> ivs(loopRanges.size());
-  auto pivs = IndexHandle::makeIndexHandlePointers(ivs);
+  auto pivs = makeIndexHandlePointers(ivs);
   LoopNestRangeBuilder(pivs, loopRanges)([&] {
     auto b = ScopedContext::getBuilder();
     auto loc = ScopedContext::getLocation();
diff --git a/third_party/mlir/lib/Linalg/Utils/Utils.cpp b/third_party/mlir/lib/Linalg/Utils/Utils.cpp
index 850aefe0eae..d31fe0d3006 100644
--- a/third_party/mlir/lib/Linalg/Utils/Utils.cpp
+++ b/third_party/mlir/lib/Linalg/Utils/Utils.cpp
@@ -106,16 +106,6 @@ ValueHandle LoopNestRangeBuilder::LoopNestRangeBuilder::operator()(
   return ValueHandle::null();
 }
 
-SmallVector<Value *, 8> mlir::linalg::getViewSizes(LinalgOp &linalgOp) {
-  SmallVector<Value *, 8> res;
-  for (auto v : linalgOp.getInputsAndOutputs()) {
-    ViewType t = v->getType().cast<ViewType>();
-    for (unsigned i = 0; i < t.getRank(); ++i)
-      res.push_back(linalg::intrinsics::dim(v, i));
-  }
-  return res;
-}
-
 static Value *emitOrFoldComposedAffineApply(OpBuilder &b, Location loc,
                                             AffineMap map,
                                             ArrayRef<Value *> operandsRef,
diff --git a/third_party/mlir/lib/Transforms/LowerVectorTransfers.cpp b/third_party/mlir/lib/Transforms/LowerVectorTransfers.cpp
index ef67488023f..cda62d9ddc0 100644
--- a/third_party/mlir/lib/Transforms/LowerVectorTransfers.cpp
+++ b/third_party/mlir/lib/Transforms/LowerVectorTransfers.cpp
@@ -273,10 +273,9 @@ VectorTransferRewriter<VectorTransferReadOp>::matchAndRewrite(
   IndexedValue remote(transfer.getMemRef());
   MemRefView view(transfer.getMemRef());
   VectorView vectorView(transfer.getVector());
-  SmallVector<IndexHandle, 8> ivs =
-      IndexHandle::makeIndexHandles(vectorView.rank());
+  SmallVector<IndexHandle, 8> ivs = makeIndexHandles(vectorView.rank());
   SmallVector<ValueHandle *, 8> pivs =
-      IndexHandle::makeIndexHandlePointers(ivs);
+      makeIndexHandlePointers(MutableArrayRef<IndexHandle>(ivs));
   coalesceCopy(transfer, &pivs, &vectorView);
 
   auto lbs = vectorView.getLbs();
@@ -335,10 +334,8 @@ VectorTransferRewriter<VectorTransferWriteOp>::matchAndRewrite(
   MemRefView view(transfer.getMemRef());
   ValueHandle vectorValue(transfer.getVector());
   VectorView vectorView(transfer.getVector());
-  SmallVector<IndexHandle, 8> ivs =
-      IndexHandle::makeIndexHandles(vectorView.rank());
-  SmallVector<ValueHandle *, 8> pivs =
-      IndexHandle::makeIndexHandlePointers(ivs);
+  SmallVector<IndexHandle, 8> ivs = makeIndexHandles(vectorView.rank());
+  SmallVector<ValueHandle *, 8> pivs = makeIndexHandlePointers(ivs);
   coalesceCopy(transfer, &pivs, &vectorView);
 
   auto lbs = vectorView.getLbs();
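End-to-end, the lowering is now driven as a standard conversion: the pass builds the pattern list, declares the Affine, Loop, and Standard dialects legal, and lets `applyPartialConversion` replace each Linalg op. A sketch of scheduling it, assuming the raw-pointer `PassManager::addPass` contemporary with the `FunctionPassBase *` factory above:

```
// Sketch only: pass-manager API of the same vintage as this patch.
#include "mlir/Linalg/Passes.h"
#include "mlir/Pass/PassManager.h"

static void addLinalgToLoopsLowering(mlir::PassManager &pm) {
  // Rewrites linalg ops into loop.for nests around scalar load/compute/store.
  pm.addPass(mlir::linalg::createLowerLinalgToLoopsPass());
}
```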