Implement Linalg to loops lowering as a pattern

This CL rewrites the Linalg ops-to-loops transformations as patterns that can be targeted directly from Tablegen. Reliance on OperationFolder is removed; to cope with that, we introduce local folding patterns that are applied greedily.
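For illustration, the lowering can now be driven from a DRR pattern; this sketch mirrors the test pattern added to TestLinalgTransformPatterns.td in this CL:

    def : Pattern<(DotOp:$op $a, $b, $c),
                  [(LinalgOpToLoops<"DotOp"> $op)],
                  [(Constraint<HasLinalgTransformMarker<"REG">> $op)]>;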

PiperOrigin-RevId: 282765550
Change-Id: I1cb9dd53a0364d965411b43c0ef1b52837e6af4a
Nicolas Vasilache 2019-11-27 07:31:41 -08:00 committed by TensorFlower Gardener
parent 81f844c1ff
commit bb00fa819e
8 changed files with 261 additions and 129 deletions


@@ -46,7 +46,7 @@ Status LowerLHLOToGPU(mlir::ModuleOp module) {
   // Transform element-wise operations to LinAlg.
   pm.addPass(::mlir::xla_lhlo::createLegalizeToLinalgPass());
   // Go from affine to normal loops.
-  pm.addPass(::mlir::linalg::createLowerLinalgToLoopsPass());
+  pm.addPass(::mlir::linalg::createConvertLinalgToLoopsPass());
   // Lower affine to ordinary loops.
   pm.addPass(::mlir::createLowerAffinePass());
   // Move constants out of the loop.


@@ -2212,8 +2212,8 @@ cc_library(
         "lib/Dialect/Linalg/IR/LinalgOps.cpp",
         "lib/Dialect/Linalg/IR/LinalgTypes.cpp",
         "lib/Dialect/Linalg/Transforms/Fusion.cpp",
+        "lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp",
         "lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp",
-        "lib/Dialect/Linalg/Transforms/LowerToLoops.cpp",
         "lib/Dialect/Linalg/Transforms/Promotion.cpp",
         "lib/Dialect/Linalg/Transforms/Tiling.cpp",
         "lib/Dialect/Linalg/Utils/Utils.cpp",


@@ -39,9 +39,16 @@ createLinalgTilingPass(ArrayRef<int64_t> tileSizes = {});
 std::unique_ptr<OpPassBase<FuncOp>>
 createLinalgPromotionPass(bool dynamicBuffers);
-std::unique_ptr<OpPassBase<FuncOp>> createLowerLinalgToLoopsPass();
+/// Create a pass to convert Linalg operations to loop.for loops and
+/// std.load/std.store accesses.
+std::unique_ptr<OpPassBase<FuncOp>> createConvertLinalgToLoopsPass();
-/// Create a pass to convert vector operations to the LLVMIR dialect.
+/// Create a pass to convert Linalg operations to affine.for loops and
+/// affine_load/affine_store accesses.
+/// Placeholder for now, this is NYI.
+std::unique_ptr<OpPassBase<FuncOp>> createConvertLinalgToAffineLoopsPass();
+
+/// Create a pass to convert Linalg operations to the LLVMIR dialect.
 std::unique_ptr<OpPassBase<ModuleOp>> createConvertLinalgToLLVMPass();
 
 } // namespace linalg


@@ -62,4 +62,15 @@ class TileLinalgOp<list<int> sizes, string value> : NativeCodeCall<
   StrJoinInt<sizes>.result # "}, \"" # value # "\")))" #
   "  return matchFailure();">;
 
+//===----------------------------------------------------------------------===//
+// Linalg to loop patterns.
+//===----------------------------------------------------------------------===//
+class LinalgOpToLoops<string OpType> : NativeCodeCall<
+  "if (failed(linalgOpToLoops<" # OpType # ">($_builder, $0))) " #
+  "  return matchFailure();">;
+
+class LinalgOpToAffineLoops<string OpType> : NativeCodeCall<
+  "if (failed(linalgOpToAffineLoops<" # OpType # ">($_builder, $0))) " #
+  "  return matchFailure();">;
+
 #endif // LINALG_TRANSFORMS


@@ -35,20 +35,6 @@ struct LinalgTransforms {
   static const StringLiteral kLinalgTransformMarker;
 };
 
-// Declarative transformation used in tablegen patterns.
-// Tiles `op` by `sizes` and sets the attribute `kLinalgTransformMarker` to
-// `linalgMarker`.
-LogicalResult tileLinalgOpAndSetMarker(PatternRewriter &rewriter, Operation *op,
-                                       ArrayRef<int64_t> sizes,
-                                       StringRef linalgMarker);
-
-// Declarative transformation used in tablegen patterns.
-// Tiles `op` by `sizes`, fuses the producers of `operandIndicesToFuse` and sets
-// the attribute `kLinalgTransformMarker` to `linalgMarker`.
-LogicalResult tileAndFuseLinalgOpAndSetMarker(
-    PatternRewriter &rewriter, Operation *op, ArrayRef<int64_t> sizes,
-    ArrayRef<int64_t> operandIndicesToFuse, StringRef linalgMarker);
-
 namespace detail {
 // Implementation detail of isProducedByOpOfType avoids the need for explicit
 // template instantiations.
@@ -65,6 +51,33 @@ bool isProducedByOpOfType(Operation *consumerOp, Value *consumedView) {
       consumerOp, consumedView, [](Operation *op) { return isa<OpTy>(op); });
 }
 
+////////////////////////////////////////////////////////////////////////////////
+// The following Declarative Rewrite Rule (DRR) helpers are used in rewrite
+// patterns. As such, they must not call into `rewriter.erase/replace` APIs and
+// it is the responsibility of the enclosing PatternRewriter to erase on
+// success.
+////////////////////////////////////////////////////////////////////////////////
+// Tiles `op` by `sizes` and sets the attribute `kLinalgTransformMarker` to
+// `linalgMarker`.
+LogicalResult tileLinalgOpAndSetMarker(PatternRewriter &rewriter, Operation *op,
+                                       ArrayRef<int64_t> sizes,
+                                       StringRef linalgMarker);
+
+// Tiles `op` by `sizes`, fuses the producers of `operandIndicesToFuse` and sets
+// the attribute `kLinalgTransformMarker` to `linalgMarker`.
+LogicalResult tileAndFuseLinalgOpAndSetMarker(
+    PatternRewriter &rewriter, Operation *op, ArrayRef<int64_t> sizes,
+    ArrayRef<int64_t> operandIndicesToFuse, StringRef linalgMarker);
+
+// Emits a loop nest of `loop.for` with the proper body for `op`.
+template <typename ConcreteOp>
+LogicalResult linalgOpToLoops(PatternRewriter &rewriter, Operation *op);
+
+// Emits a loop nest of `affine.for` with the proper body for `op`.
+template <typename ConcreteOp>
+LogicalResult linalgOpToAffineLoops(PatternRewriter &rewriter, Operation *op);
+
 } // namespace linalg
 } // namespace mlir
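To make the erase-on-success contract above concrete, here is a minimal sketch of an enclosing pattern (hypothetical, modeled on the LinalgRewritePattern added in LinalgToLoops.cpp below): the DRR helper only emits the loop nest, and the pattern performs the erasure.

    // Sketch only: the helper emits new IR but never erases; the enclosing
    // pattern erases `op` once the helper reports success.
    PatternMatchResult matchAndRewrite(Operation *op,
                                       PatternRewriter &rewriter) const override {
      if (failed(linalgOpToLoops<MatmulOp>(rewriter, op)))
        return matchFailure();
      rewriter.eraseOp(op); // erasure stays with the enclosing pattern
      return matchSuccess();
    }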


@@ -5,7 +5,7 @@ add_llvm_library(MLIRLinalg
   IR/LinalgTypes.cpp
   Transforms/Fusion.cpp
   Transforms/LinalgTransforms.cpp
-  Transforms/LowerToLoops.cpp
+  Transforms/LinalgToLoops.cpp
   Transforms/Promotion.cpp
   Transforms/Tiling.cpp
   Utils/Utils.cpp


@@ -19,6 +19,7 @@
 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
 #include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
 #include "mlir/Dialect/Linalg/Passes.h"
+#include "mlir/Dialect/Linalg/Transforms/LinalgTransforms.h"
 #include "mlir/Dialect/Linalg/Utils/Intrinsics.h"
 #include "mlir/Dialect/Linalg/Utils/Utils.h"
 #include "mlir/Dialect/LoopOps/LoopOps.h"
@@ -41,12 +42,14 @@ using namespace mlir::linalg;
 using namespace mlir::linalg::intrinsics;
 
 using IndexedStdValue = TemplatedIndexedValue<std_load, std_store>;
 using IndexedAffineValue = TemplatedIndexedValue<affine_load, affine_store>;
 
 using edsc::op::operator+;
 using edsc::op::operator==;
 
 static SmallVector<ValueHandle, 8>
-foldedAffineApplies(OpBuilder &b, Location loc, AffineMap map,
-                    ArrayRef<Value *> vals, OperationFolder *folder) {
+makeCanonicalAffineApplies(OpBuilder &b, Location loc, AffineMap map,
+                           ArrayRef<Value *> vals) {
   assert(map.getNumSymbols() == 0);
   assert(map.getNumInputs() == vals.size());
   SmallVector<ValueHandle, 8> res;
@@ -56,17 +59,16 @@ foldedAffineApplies(OpBuilder &b, Location loc, AffineMap map,
     auto exprMap = AffineMap::get(dims, 0, e);
     SmallVector<Value *, 4> operands(vals.begin(), vals.end());
     canonicalizeMapAndOperands(&exprMap, &operands);
-    res.push_back(affine_apply(folder, exprMap, operands));
+    res.push_back(affine_apply(exprMap, operands));
   }
   return res;
 }
 
 static SmallVector<Value *, 4> permuteIvs(ArrayRef<Value *> ivs,
-                                          Optional<AffineMap> permutation,
-                                          OperationFolder *folder) {
+                                          Optional<AffineMap> permutation) {
   return permutation ? applyMapToValues(ScopedContext::getBuilder(),
                                         ScopedContext::getLocation(),
-                                        permutation.getValue(), ivs, folder)
+                                        permutation.getValue(), ivs)
                      : SmallVector<Value *, 4>(ivs.begin(), ivs.end());
 }
@@ -75,20 +77,17 @@ static SmallVector<Value *, 4> permuteIvs(ArrayRef<Value *> ivs,
 // which new loops will be created.
 static SmallVector<Value *, 4> emitLoopRanges(OpBuilder &b, Location loc,
                                               AffineMap map,
-                                              ArrayRef<Value *> allViewSizes,
-                                              OperationFolder *folder);
+                                              ArrayRef<Value *> allViewSizes);
 SmallVector<Value *, 4> emitLoopRanges(OpBuilder &b, Location loc,
                                        AffineMap map,
-                                       ArrayRef<Value *> allViewSizes,
-                                       OperationFolder *folder) {
+                                       ArrayRef<Value *> allViewSizes) {
   // Apply `map` to get view sizes in loop order.
-  auto sizes = applyMapToValues(b, loc, map, allViewSizes, folder);
+  auto sizes = applyMapToValues(b, loc, map, allViewSizes);
   // Create a new range with the applied tile sizes.
   ScopedContext scope(b, loc);
   SmallVector<Value *, 4> res;
   for (unsigned idx = 0, e = map.getNumResults(); idx < e; ++idx) {
-    res.push_back(range(constant_index(folder, 0), sizes[idx],
-                        constant_index(folder, 1)));
+    res.push_back(range(constant_index(0), sizes[idx], constant_index(1)));
   }
   return res;
 }
@@ -99,14 +98,14 @@ class LinalgScopedEmitter {};
 template <typename IndexedValueType>
 class LinalgScopedEmitter<IndexedValueType, CopyOp> {
 public:
-  static void emitScalarImplementation(ArrayRef<Value *> allIvs, CopyOp copyOp,
-                                       OperationFolder *folder) {
+  static void emitScalarImplementation(ArrayRef<Value *> allIvs,
+                                       CopyOp copyOp) {
     auto nPar = copyOp.getNumParallelLoops();
     assert(nPar == allIvs.size());
     auto inputIvs =
-        permuteIvs(allIvs.take_front(nPar), copyOp.inputPermutation(), folder);
+        permuteIvs(allIvs.take_front(nPar), copyOp.inputPermutation());
     auto outputIvs =
-        permuteIvs(allIvs.take_front(nPar), copyOp.outputPermutation(), folder);
+        permuteIvs(allIvs.take_front(nPar), copyOp.outputPermutation());
     SmallVector<IndexHandle, 8> iivs(inputIvs.begin(), inputIvs.end());
     SmallVector<IndexHandle, 8> oivs(outputIvs.begin(), outputIvs.end());
     IndexedValueType O(copyOp.getOutput(0)), I(copyOp.getInput(0));
@@ -122,8 +121,8 @@ public:
 template <typename IndexedValueType>
 class LinalgScopedEmitter<IndexedValueType, FillOp> {
 public:
-  static void emitScalarImplementation(ArrayRef<Value *> allIvs, FillOp fillOp,
-                                       OperationFolder *folder) {
+  static void emitScalarImplementation(ArrayRef<Value *> allIvs,
+                                       FillOp fillOp) {
     auto nPar = fillOp.getNumParallelLoops();
     assert(nPar == allIvs.size());
     auto ivs =
@@ -139,8 +138,7 @@ public:
 template <typename IndexedValueType>
 class LinalgScopedEmitter<IndexedValueType, DotOp> {
 public:
-  static void emitScalarImplementation(ArrayRef<Value *> allIvs, DotOp dotOp,
-                                       OperationFolder *folder) {
+  static void emitScalarImplementation(ArrayRef<Value *> allIvs, DotOp dotOp) {
     assert(allIvs.size() == 1);
     IndexHandle r_i(allIvs[0]);
     IndexedValueType A(dotOp.getInput(0)), B(dotOp.getInput(1)),
@@ -154,8 +152,7 @@ template <typename IndexedValueType>
 class LinalgScopedEmitter<IndexedValueType, MatvecOp> {
 public:
   static void emitScalarImplementation(ArrayRef<Value *> allIvs,
-                                       MatvecOp matvecOp,
-                                       OperationFolder *folder) {
+                                       MatvecOp matvecOp) {
     assert(allIvs.size() == 2);
     IndexHandle i(allIvs[0]), r_j(allIvs[1]);
     IndexedValueType A(matvecOp.getInput(0)), B(matvecOp.getInput(1)),
@@ -169,8 +166,7 @@ template <typename IndexedValueType>
 class LinalgScopedEmitter<IndexedValueType, MatmulOp> {
 public:
   static void emitScalarImplementation(ArrayRef<Value *> allIvs,
-                                       MatmulOp matmulOp,
-                                       OperationFolder *folder) {
+                                       MatmulOp matmulOp) {
     assert(allIvs.size() == 3);
     IndexHandle i(allIvs[0]), j(allIvs[1]), r_k(allIvs[2]);
     IndexedValueType A(matmulOp.getInput(0)), B(matmulOp.getInput(1)),
@@ -183,17 +179,17 @@ public:
 template <typename IndexedValueType>
 class LinalgScopedEmitter<IndexedValueType, ConvOp> {
 public:
-  static void emitScalarImplementation(ArrayRef<Value *> allIvs, ConvOp convOp,
-                                       OperationFolder *folder) {
+  static void emitScalarImplementation(ArrayRef<Value *> allIvs,
+                                       ConvOp convOp) {
     auto b = ScopedContext::getBuilder();
     auto loc = ScopedContext::getLocation();
     auto maps = loopToOperandRangesMaps(convOp);
     SmallVector<ValueHandle, 8> fIdx(
-        foldedAffineApplies(b, loc, maps[0], allIvs, folder));
+        makeCanonicalAffineApplies(b, loc, maps[0], allIvs));
     SmallVector<ValueHandle, 8> imIdx(
-        foldedAffineApplies(b, loc, maps[1], allIvs, folder));
+        makeCanonicalAffineApplies(b, loc, maps[1], allIvs));
     SmallVector<ValueHandle, 8> oIdx(
-        foldedAffineApplies(b, loc, maps[2], allIvs, folder));
+        makeCanonicalAffineApplies(b, loc, maps[2], allIvs));
     IndexedValueType F(convOp.filter()), I(convOp.input()), O(convOp.output());
     // Emit scalar form.
     O(oIdx) += F(fIdx) * I(imIdx);
@@ -234,8 +230,7 @@ template <typename IndexedValueType>
 class LinalgScopedEmitter<IndexedValueType, GenericOp> {
 public:
   static void emitScalarImplementation(ArrayRef<Value *> allIvs,
-                                       GenericOp genericOp,
-                                       OperationFolder *folder) {
+                                       GenericOp genericOp) {
     auto b = ScopedContext::getBuilder();
     auto loc = ScopedContext::getLocation();
     using edsc::intrinsics::detail::ValueHandleArray;
@@ -245,15 +240,15 @@ public:
     // 1.a. Emit std_load from input views.
     for (unsigned i = 0; i < nInputs; ++i) {
-      ValueHandleArray indexing(foldedAffineApplies(
-          b, loc, genericOp.getInputIndexingMap(i), allIvs, folder));
+      ValueHandleArray indexing(makeCanonicalAffineApplies(
+          b, loc, genericOp.getInputIndexingMap(i), allIvs));
       indexedValues[i] = std_load(genericOp.getInput(i), indexing);
     }
 
     // 1.b. Emit std_load from output views.
     for (unsigned i = 0; i < nOutputs; ++i) {
-      ValueHandleArray indexing(foldedAffineApplies(
-          b, loc, genericOp.getOutputIndexingMap(i), allIvs, folder));
+      ValueHandleArray indexing(makeCanonicalAffineApplies(
+          b, loc, genericOp.getOutputIndexingMap(i), allIvs));
       indexedValues[nInputs + i] = std_load(genericOp.getOutput(i), indexing);
     }
@@ -265,8 +260,8 @@ public:
     // 3. Emit std_store.
     for (unsigned i = 0; i < nOutputs; ++i) {
-      ValueHandleArray indexing(foldedAffineApplies(
-          b, loc, genericOp.getOutputIndexingMap(i), allIvs, folder));
+      ValueHandleArray indexing(makeCanonicalAffineApplies(
+          b, loc, genericOp.getOutputIndexingMap(i), allIvs));
       std_store(callOp->getResult(i), genericOp.getOutput(i), indexing);
     }
     return;
@@ -288,8 +283,8 @@ public:
     auto *yieldOp = cast<YieldOp>(block.back()).getOperation();
     assert(yieldOp->getNumOperands() == nOutputs);
     for (unsigned i = 0; i < nOutputs; ++i) {
-      ValueHandleArray indexing(foldedAffineApplies(
-          b, loc, genericOp.getOutputIndexingMap(i), allIvs, folder));
+      ValueHandleArray indexing(makeCanonicalAffineApplies(
+          b, loc, genericOp.getOutputIndexingMap(i), allIvs));
       std_store(map.lookup(yieldOp->getOperand(i)), genericOp.getOutput(i),
                 indexing);
     }
@@ -330,8 +325,7 @@ template <typename IndexedValueType>
 class LinalgScopedEmitter<IndexedValueType, IndexedGenericOp> {
 public:
   static void emitScalarImplementation(ArrayRef<Value *> allIvs,
-                                       IndexedGenericOp indexedGenericOp,
-                                       OperationFolder *folder) {
+                                       IndexedGenericOp indexedGenericOp) {
     auto b = ScopedContext::getBuilder();
     auto loc = ScopedContext::getLocation();
     using edsc::intrinsics::detail::ValueHandleArray;
@@ -346,16 +340,16 @@ public:
     // 1.a. Emit std_load from input views.
     for (unsigned i = 0; i < nInputs; ++i) {
-      ValueHandleArray indexing(foldedAffineApplies(
-          b, loc, indexedGenericOp.getInputIndexingMap(i), allIvs, folder));
+      ValueHandleArray indexing(makeCanonicalAffineApplies(
+          b, loc, indexedGenericOp.getInputIndexingMap(i), allIvs));
       indexedValues[nLoops + i] =
           std_load(indexedGenericOp.getInput(i), indexing);
     }
 
     // 1.b. Emit std_load from output views.
     for (unsigned i = 0; i < nOutputs; ++i) {
-      ValueHandleArray indexing(foldedAffineApplies(
-          b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs, folder));
+      ValueHandleArray indexing(makeCanonicalAffineApplies(
+          b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs));
       indexedValues[nLoops + nInputs + i] =
           std_load(indexedGenericOp.getOutput(i), indexing);
     }
@@ -367,8 +361,8 @@ public:
     // 3. Emit std_store.
     for (unsigned i = 0; i < nOutputs; ++i) {
-      ValueHandleArray indexing(foldedAffineApplies(
-          b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs, folder));
+      ValueHandleArray indexing(makeCanonicalAffineApplies(
+          b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs));
       std_store(callOp->getResult(i), indexedGenericOp.getOutput(i),
                 indexing);
     }
@@ -391,96 +385,110 @@ public:
     auto *yieldOp = cast<YieldOp>(block.back()).getOperation();
     assert(yieldOp->getNumOperands() == nOutputs);
     for (unsigned i = 0; i < nOutputs; ++i) {
-      ValueHandleArray indexing(foldedAffineApplies(
-          b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs, folder));
+      ValueHandleArray indexing(makeCanonicalAffineApplies(
+          b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs));
       std_store(map.lookup(yieldOp->getOperand(i)),
                 indexedGenericOp.getOutput(i), indexing);
     }
   }
 };
+namespace {
+// This struct is for factoring out the implementation and support template
+// instantiations in the following 2 cases:
+//   1. Appending to a list of patterns via RewritePatternList.
+//   2. Direct invocation via `linalgOpToLoops` and `linalgOpToAffineLoops`.
+// The implementation must work both in DRR and inside a RewritePattern. As a
+// consequence, (1) it is only allowed to emit new ops if the match is
+// guaranteed to be a success, (2) it is not allowed erase/replace, and (3) an
+// encompassing pattern must take care of the erasure logic.
+template <typename LoopTy, typename IndexedValueTy, typename ConcreteOpTy>
+class LinalgOpToLoopsImpl {
+public:
+  static LogicalResult doit(Operation *op, PatternRewriter &rewriter);
+};
+} // namespace
+
+template <typename LoopTy, typename IndexedValueTy, typename ConcreteOpTy>
+LogicalResult LinalgOpToLoopsImpl<LoopTy, IndexedValueTy, ConcreteOpTy>::doit(
+    Operation *op, PatternRewriter &rewriter) {
+  OpBuilder b(op);
+  ScopedContext scope(b, op->getLoc());
+  // The flattened loopToOperandRangesMaps is expected to be an invertible
+  // permutation map (which is asserted in the inverse calculation).
+  auto linalgOp = cast<ConcreteOpTy>(op);
+  auto invertedMap =
+      inversePermutation(concatAffineMaps(loopToOperandRangesMaps(linalgOp)));
+  if (!invertedMap) {
+    LinalgScopedEmitter<IndexedValueTy, ConcreteOpTy>::emitScalarImplementation(
+        {}, linalgOp);
+    return success();
+  }
+  auto nPar = linalgOp.getNumParallelLoops();
+  auto nRed = linalgOp.getNumReductionLoops();
+  auto nWin = linalgOp.getNumWindowLoops();
+  SmallVector<IndexHandle, 4> allIvs(nPar + nRed + nWin);
+  SmallVector<ValueHandle *, 4> allPIvs = makeIndexHandlePointers(allIvs);
+  auto loopRanges = emitLoopRanges(scope.getBuilder(), scope.getLocation(),
+                                   invertedMap, getViewSizes(linalgOp));
+  assert(loopRanges.size() == allIvs.size());
+  LoopNestRangeBuilder(allPIvs, loopRanges)([&] {
+    auto allIvValues = extractValues(allIvs);
+    LinalgScopedEmitter<IndexedValueTy, ConcreteOpTy>::emitScalarImplementation(
+        allIvValues, linalgOp);
+  });
+  return success();
+}
 template <typename LoopType, typename IndexedValueType, typename ConcreteOp>
 class LinalgRewritePattern : public RewritePattern {
 public:
   explicit LinalgRewritePattern(MLIRContext *context)
-      : RewritePattern(ConcreteOp::getOperationName(), /*benefit=*/1, context),
-        folder(context) {}
+      : RewritePattern(ConcreteOp::getOperationName(), 1, context) {}
 
   PatternMatchResult matchAndRewrite(Operation *op,
                                      PatternRewriter &rewriter) const override {
-    OpBuilder b(op);
-    ScopedContext scope(b, op->getLoc());
-    // The flattened loopToOperandRangesMaps is expected to be an invertible
-    // permutation map (which is asserted in the inverse calculation).
-    auto linalgOp = cast<ConcreteOp>(op);
-    auto invertedMap =
-        inversePermutation(concatAffineMaps(loopToOperandRangesMaps(linalgOp)));
-    if (!invertedMap) {
-      LinalgScopedEmitter<IndexedValueType,
-                          ConcreteOp>::emitScalarImplementation({}, linalgOp,
-                                                                &folder);
-      rewriter.eraseOp(op);
-      return matchSuccess();
-    }
-    auto nPar = linalgOp.getNumParallelLoops();
-    auto nRed = linalgOp.getNumReductionLoops();
-    auto nWin = linalgOp.getNumWindowLoops();
-    SmallVector<IndexHandle, 4> allIvs(nPar + nRed + nWin);
-    SmallVector<ValueHandle *, 4> allPIvs = makeIndexHandlePointers(allIvs);
-    auto loopRanges =
-        emitLoopRanges(scope.getBuilder(), scope.getLocation(), invertedMap,
-                       getViewSizes(linalgOp), &folder);
-    assert(loopRanges.size() == allIvs.size());
-    // clang-format off
-    LoopNestRangeBuilder(allPIvs, loopRanges)([&] {
-      auto allIvValues = extractValues(allIvs);
-      LinalgScopedEmitter<IndexedValueType,
-                          ConcreteOp>::emitScalarImplementation(allIvValues,
-                                                                linalgOp,
-                                                                &folder);
-    });
-    // clang-format on
+    using Impl = LinalgOpToLoopsImpl<LoopType, IndexedValueType, ConcreteOp>;
+    if (failed(Impl::doit(op, rewriter)))
+      return matchFailure();
     rewriter.eraseOp(op);
     return matchSuccess();
   }
-
-  mutable OperationFolder folder;
 };
 // Helper classes for type list expansion.
 template <typename LoopType, typename IndexedValueType, typename... LinalgOps>
-class ConversionList;
+class RewritePatternList;
 
 template <typename LoopType, typename IndexedValueType>
-class ConversionList<LoopType, IndexedValueType> {
+class RewritePatternList<LoopType, IndexedValueType> {
 public:
   static void build(OwningRewritePatternList &patterns, MLIRContext *ctx) {}
 };
 
 template <typename LoopType, typename IndexedValueType, typename ConcreteOp,
           typename... LinalgOps>
-class ConversionList<LoopType, IndexedValueType, ConcreteOp, LinalgOps...> {
+class RewritePatternList<LoopType, IndexedValueType, ConcreteOp, LinalgOps...> {
 public:
   static void build(OwningRewritePatternList &patterns, MLIRContext *ctx) {
     patterns
         .insert<LinalgRewritePattern<LoopType, IndexedValueType, ConcreteOp>>(
            ctx);
-    ConversionList<LoopType, IndexedValueType, LinalgOps...>::build(patterns,
-                                                                    ctx);
+    RewritePatternList<LoopType, IndexedValueType, LinalgOps...>::build(
+        patterns, ctx);
   }
 };
 
 /// Populate the given list with patterns that convert from Linalg to LLVM.
 template <typename LoopType, typename IndexedValueType>
-void ForOpRewritePatterns(OwningRewritePatternList &patterns,
-                          MLIRContext *ctx) {
-  ConversionList<LoopType, IndexedValueType,
+void FillRewritePatterns(OwningRewritePatternList &patterns, MLIRContext *ctx) {
+  RewritePatternList<LoopType, IndexedValueType,
 #define GET_OP_LIST
 #include "mlir/Dialect/Linalg/IR/LinalgLibraryOps.cpp.inc"
-                 >::build(patterns, ctx);
+                     >::build(patterns, ctx);
 }
 
 namespace {
@@ -491,28 +499,114 @@ struct LowerLinalgToLoopsPass
 };
 } // namespace
 
+// Local folding pattern for AffineApplyOp that we can apply greedily.
+// This replaces AffineApplyOp by the proper value in cases where the associated
+// map is trivial. A trivial map here is defined as a map with a single result
+// and either:
+//   1. Zero operand + returns a single AffineConstantExpr
+//   2. One operand + returns a single AffineDimExpr
+//   3. One operand + returns a single AffineSymbolExpr
+//
+// In the first case, the AffineApplyOp is replaced by a new constant. In the
+// other cases, it is replaced by its unique operand.
+struct FoldAffineOp : public RewritePattern {
+  FoldAffineOp(MLIRContext *context)
+      : RewritePattern(AffineApplyOp::getOperationName(), 0, context) {}
+
+  PatternMatchResult matchAndRewrite(Operation *op,
+                                     PatternRewriter &rewriter) const override {
+    AffineApplyOp affineApplyOp = cast<AffineApplyOp>(op);
+    auto map = affineApplyOp.getAffineMap();
+    if (map.getNumResults() != 1 || map.getNumInputs() > 1)
+      return matchFailure();
+
+    AffineExpr expr = map.getResult(0);
+    if (map.getNumInputs() == 0) {
+      if (auto val = expr.dyn_cast<AffineConstantExpr>()) {
+        rewriter.replaceOpWithNewOp<ConstantIndexOp>(op, val.getValue());
+        return matchSuccess();
+      }
+      return matchFailure();
+    }
+    if (expr.dyn_cast<AffineDimExpr>() || expr.dyn_cast<AffineSymbolExpr>()) {
+      rewriter.replaceOp(op, op->getOperand(0));
+      return matchSuccess();
+    }
+    return matchFailure();
+  }
+};
 template <typename LoopType, typename IndexedValueType>
 void LowerLinalgToLoopsPass<LoopType, IndexedValueType>::runOnFunction() {
+  auto *context = &this->getContext();
   OwningRewritePatternList patterns;
-  ForOpRewritePatterns<LoopType, IndexedValueType>(patterns,
-                                                   &this->getContext());
-  ConversionTarget target(this->getContext());
-  target.addLegalDialect<AffineOpsDialect>();
-  target.addLegalDialect<loop::LoopOpsDialect>();
-  target.addLegalDialect<StandardOpsDialect>();
-  if (failed(applyPartialConversion(this->getFunction(), target, patterns))) {
-    this->signalPassFailure();
-  }
+  // Canonicalization and folding patterns applied greedily allow cleaning up
+  // the emitted IR on the fly.
+  // TODO(ntv) fold view and subview ops?
+  FillRewritePatterns<LoopType, IndexedValueType>(patterns, context);
+  DimOp::getCanonicalizationPatterns(patterns, context);
+  AffineApplyOp::getCanonicalizationPatterns(patterns, context);
+  patterns.insert<FoldAffineOp>(context);
+  // Just apply the patterns greedily.
+  applyPatternsGreedily(this->getFunction(), patterns);
 }
+/// Create a pass to convert Linalg operations to loop.for loops and
+/// std.load/std.store accesses.
 std::unique_ptr<OpPassBase<FuncOp>>
-mlir::linalg::createLowerLinalgToLoopsPass() {
+mlir::linalg::createConvertLinalgToLoopsPass() {
   return std::make_unique<
       LowerLinalgToLoopsPass<loop::ForOp, IndexedStdValue>>();
 }
+
+/// Create a pass to convert Linalg operations to affine.for loops and
+/// affine_load/affine_store accesses.
+/// Placeholder for now, this is NYI.
+std::unique_ptr<OpPassBase<FuncOp>>
+mlir::linalg::createConvertLinalgToAffineLoopsPass() {
+  return std::make_unique<
+      LowerLinalgToLoopsPass<AffineForOp, IndexedAffineValue>>();
+}
+// Emits a loop nest of `loop.for` with the proper body for `op`.
+template <typename ConcreteOp>
+LogicalResult mlir::linalg::linalgOpToLoops(PatternRewriter &rewriter,
+                                            Operation *op) {
+  return LinalgOpToLoopsImpl<loop::ForOp, IndexedStdValue, ConcreteOp>::doit(
+      op, rewriter);
+}
+
+// Emits a loop nest of `affine.for` with the proper body for `op`.
+template <typename ConcreteOp>
+LogicalResult mlir::linalg::linalgOpToAffineLoops(PatternRewriter &rewriter,
+                                                  Operation *op) {
+  return LinalgOpToLoopsImpl<AffineForOp, IndexedAffineValue, ConcreteOp>::doit(
+      op, rewriter);
+}
+
+// TODO(ntv) Need to make these instantiations more future-proof to avoid the
+// need to update as soon as we add new ops.
+#define INSTANTIATE_LINALG_OP_TO_LOOPS(OP_TYPE)                                \
+  template LogicalResult mlir::linalg::linalgOpToLoops<OP_TYPE>(              \
+      PatternRewriter & rewriter, Operation * op);                            \
+  template LogicalResult mlir::linalg::linalgOpToAffineLoops<OP_TYPE>(        \
+      PatternRewriter & rewriter, Operation * op);
+
+INSTANTIATE_LINALG_OP_TO_LOOPS(CopyOp)
+INSTANTIATE_LINALG_OP_TO_LOOPS(FillOp)
+INSTANTIATE_LINALG_OP_TO_LOOPS(DotOp)
+INSTANTIATE_LINALG_OP_TO_LOOPS(MatvecOp)
+INSTANTIATE_LINALG_OP_TO_LOOPS(MatmulOp)
+INSTANTIATE_LINALG_OP_TO_LOOPS(ConvOp)
+INSTANTIATE_LINALG_OP_TO_LOOPS(GenericOp)
+INSTANTIATE_LINALG_OP_TO_LOOPS(IndexedGenericOp)
 static PassRegistration<LowerLinalgToLoopsPass<loop::ForOp, IndexedStdValue>>
     structuredLoopsPass(
-        "linalg-lower-to-loops",
+        "convert-linalg-to-loops",
         "Lower the operations from the linalg dialect into loops");
+
+static PassRegistration<LowerLinalgToLoopsPass<AffineForOp, IndexedAffineValue>>
+    affineLoopsPass(
+        "convert-linalg-to-affine-loops",
+        "Lower the operations from the linalg dialect into affine loops");


@@ -73,4 +73,11 @@ def : Pattern<(DotOp:$op $a, $b, $c),
               [(TileLinalgOp<[8], "REG"> $op)],
               [(Constraint<HasLinalgTransformMarker<"L1">> $op)]>;
 
+//===----------------------------------------------------------------------===//
+// Linalg to loops patterns.
+//===----------------------------------------------------------------------===//
+def : Pattern<(DotOp:$op $a, $b, $c),
+              [(LinalgOpToLoops<"DotOp"> $op)],
+              [(Constraint<HasLinalgTransformMarker<"REG">> $op)]>;
+
 #endif // TEST_LINALG_TRANSFORMS_PATTERNS