Refactor Linalg ops to loop lowering (NFC)
This CL modifies the LowerLinalgToLoopsPass to use RewritePattern. This will make it easier to inline Linalg generic functions and regions when emitting to loops in a subsequent CL. PiperOrigin-RevId: 261894120
This commit is contained in:
parent
ecc0b95092
commit
6c553ffc4d
1
third_party/mlir/BUILD
vendored
1
third_party/mlir/BUILD
vendored
@ -1738,6 +1738,7 @@ cc_library(
|
||||
"include/mlir/Linalg/Utils/Utils.h",
|
||||
],
|
||||
deps = [
|
||||
":AffineOps",
|
||||
":CFGTransforms",
|
||||
":EDSC",
|
||||
":IR",
|
||||
|
36
third_party/mlir/include/mlir/EDSC/Intrinsics.h
vendored
36
third_party/mlir/include/mlir/EDSC/Intrinsics.h
vendored
@ -61,20 +61,32 @@ struct IndexHandle : public ValueHandle {
|
||||
this->v = v.getValue();
|
||||
return *this;
|
||||
}
|
||||
static SmallVector<IndexHandle, 8> makeIndexHandles(unsigned rank) {
|
||||
return SmallVector<IndexHandle, 8>(rank);
|
||||
}
|
||||
static SmallVector<ValueHandle *, 8>
|
||||
makeIndexHandlePointers(SmallVectorImpl<IndexHandle> &ivs) {
|
||||
SmallVector<ValueHandle *, 8> pivs;
|
||||
pivs.reserve(ivs.size());
|
||||
for (auto &iv : ivs) {
|
||||
pivs.push_back(&iv);
|
||||
}
|
||||
return pivs;
|
||||
}
|
||||
};
|
||||
|
||||
inline SmallVector<IndexHandle, 8> makeIndexHandles(unsigned rank) {
|
||||
return SmallVector<IndexHandle, 8>(rank);
|
||||
}
|
||||
|
||||
inline SmallVector<ValueHandle *, 8>
|
||||
makeIndexHandlePointers(MutableArrayRef<IndexHandle> ivs) {
|
||||
SmallVector<ValueHandle *, 8> pivs;
|
||||
pivs.reserve(ivs.size());
|
||||
for (auto &iv : ivs) {
|
||||
pivs.push_back(&iv);
|
||||
}
|
||||
return pivs;
|
||||
}
|
||||
|
||||
/// Returns a vector of the underlying Value* from `ivs`.
|
||||
inline SmallVector<Value *, 8> extractValues(ArrayRef<IndexHandle> ivs) {
|
||||
SmallVector<Value *, 8> vals;
|
||||
vals.reserve(ivs.size());
|
||||
for (auto &iv : ivs) {
|
||||
vals.push_back(iv.getValue());
|
||||
}
|
||||
return vals;
|
||||
}
|
||||
|
||||
/// Provides a set of first class intrinsics.
|
||||
/// In the future, most of intrinsics related to Operation that don't contain
|
||||
/// other operations should be Tablegen'd.
|
||||
|
@ -436,11 +436,6 @@ private:
|
||||
};
|
||||
};
|
||||
|
||||
void emitScalarImplementation(llvm::ArrayRef<Value *> parallelIvs,
|
||||
llvm::ArrayRef<Value *> reductionIvs,
|
||||
llvm::ArrayRef<Value *> windowIvs,
|
||||
LinalgOp &linalgOp, OperationFolder &folder);
|
||||
|
||||
} // namespace linalg
|
||||
} // namespace mlir
|
||||
|
||||
|
@ -27,8 +27,10 @@ class BufferDeallocOp;
|
||||
class CopyOp;
|
||||
class DimOp;
|
||||
class FillOp;
|
||||
class LoadOp;
|
||||
class RangeOp;
|
||||
class SliceOp;
|
||||
class StoreOp;
|
||||
class ViewOp;
|
||||
namespace intrinsics {
|
||||
using buffer_alloc = mlir::edsc::intrinsics::ValueBuilder<BufferAllocOp>;
|
||||
@ -37,6 +39,8 @@ using buffer_dealloc =
|
||||
using copy = mlir::edsc::intrinsics::OperationBuilder<CopyOp>;
|
||||
using dim = mlir::edsc::intrinsics::ValueBuilder<linalg::DimOp>;
|
||||
using fill = mlir::edsc::intrinsics::OperationBuilder<FillOp>;
|
||||
using linalg_load = mlir::edsc::intrinsics::ValueBuilder<linalg::LoadOp>;
|
||||
using linalg_store = mlir::edsc::intrinsics::OperationBuilder<linalg::StoreOp>;
|
||||
using range = mlir::edsc::intrinsics::ValueBuilder<RangeOp>;
|
||||
using slice = mlir::edsc::intrinsics::ValueBuilder<SliceOp>;
|
||||
using view = mlir::edsc::intrinsics::ValueBuilder<ViewOp>;
|
||||
|
@ -21,6 +21,7 @@
|
||||
#include "mlir/Dialect/LoopOps/LoopOps.h"
|
||||
#include "mlir/EDSC/Helpers.h"
|
||||
#include "mlir/Linalg/IR/LinalgOps.h"
|
||||
#include "mlir/Linalg/Utils/Intrinsics.h"
|
||||
#include "mlir/Support/LLVM.h"
|
||||
|
||||
namespace mlir {
|
||||
@ -79,7 +80,16 @@ namespace linalg {
|
||||
/// Returns the linearized list of all view dimensions in a linalgOp. Applying
|
||||
/// the inverse, concatenated loopToOperandRangeMaps to this list allows the
|
||||
/// derivation of loop ranges for any linalgOp.
|
||||
SmallVector<Value *, 8> getViewSizes(LinalgOp &linalgOp);
|
||||
template <typename ConcreteOp>
|
||||
SmallVector<Value *, 8> getViewSizes(ConcreteOp linalgOp) {
|
||||
SmallVector<Value *, 8> res;
|
||||
for (auto v : linalgOp.getInputsAndOutputs()) {
|
||||
ViewType t = v->getType().template cast<ViewType>();
|
||||
for (unsigned i = 0; i < t.getRank(); ++i)
|
||||
res.push_back(intrinsics::dim(v, i));
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
/// Returns the values obtained by applying `map` to the list of values.
|
||||
/// Performs simplifications and foldings where possible.
|
||||
|
9
third_party/mlir/lib/Linalg/CMakeLists.txt
vendored
9
third_party/mlir/lib/Linalg/CMakeLists.txt
vendored
@ -14,4 +14,11 @@ add_llvm_library(MLIRLinalg
|
||||
DEPENDS
|
||||
intrinsics_gen
|
||||
)
|
||||
add_dependencies(MLIRLinalg MLIRLinalgOpsIncGen MLIRLinalgLibraryOpsIncGen MLIRStandardToLLVM)
|
||||
|
||||
add_dependencies(MLIRLinalg
|
||||
|
||||
MLIRAffineOps
|
||||
MLIRLinalgOpsIncGen
|
||||
MLIRLinalgLibraryOpsIncGen
|
||||
MLIRStandardToLLVM
|
||||
)
|
||||
|
178
third_party/mlir/lib/Linalg/IR/LinalgOps.cpp
vendored
178
third_party/mlir/lib/Linalg/IR/LinalgOps.cpp
vendored
@ -846,23 +846,6 @@ static SmallVector<AffineExpr, 4> concat(ArrayRef<AffineExpr> a,
|
||||
return res;
|
||||
}
|
||||
|
||||
static SmallVector<ValueHandle, 8>
|
||||
foldedAffineApplies(OpBuilder &b, Location loc, AffineMap map,
|
||||
ArrayRef<Value *> vals, OperationFolder &folder) {
|
||||
assert(map.getNumSymbols() == 0);
|
||||
assert(map.getNumInputs() == vals.size());
|
||||
SmallVector<ValueHandle, 8> res;
|
||||
res.reserve(map.getNumResults());
|
||||
auto dims = map.getNumDims();
|
||||
for (auto e : map.getResults()) {
|
||||
auto exprMap = AffineMap::get(dims, 0, e);
|
||||
SmallVector<Value *, 4> operands(vals.begin(), vals.end());
|
||||
canonicalizeMapAndOperands(&exprMap, &operands);
|
||||
res.push_back(affine_apply(folder, exprMap, operands));
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
// Note: both functions below would completely disappear with a simple tensor
|
||||
// kernel language.
|
||||
//
|
||||
@ -950,164 +933,3 @@ SmallVector<AffineMap, 4> mlir::linalg::loopToOperandRangesMaps(Operation *op) {
|
||||
}
|
||||
llvm_unreachable("Missing loopToOperandRangesMaps for op");
|
||||
}
|
||||
|
||||
static SmallVector<Value *, 4> permuteIvs(ArrayRef<Value *> ivs,
|
||||
Optional<AffineMap> permutation,
|
||||
OperationFolder &state) {
|
||||
return permutation ? applyMapToValues(ScopedContext::getBuilder(),
|
||||
ScopedContext::getLocation(),
|
||||
permutation.getValue(), ivs, state)
|
||||
: SmallVector<Value *, 4>(ivs.begin(), ivs.end());
|
||||
}
|
||||
|
||||
// Ideally this should all be Tablegen'd but there is no good story for op
|
||||
// expansion directly in MLIR for now.
|
||||
void mlir::linalg::emitScalarImplementation(
|
||||
llvm::ArrayRef<Value *> parallelIvs, llvm::ArrayRef<Value *> reductionIvs,
|
||||
llvm::ArrayRef<Value *> windowIvs, LinalgOp &linalgOp,
|
||||
OperationFolder &folder) {
|
||||
using linalg_load = ValueBuilder<linalg::LoadOp>;
|
||||
using linalg_store = OperationBuilder<linalg::StoreOp>;
|
||||
using IndexedValue = TemplatedIndexedValue<linalg_load, linalg_store>;
|
||||
using edsc::op::operator+;
|
||||
using edsc::op::operator*;
|
||||
using edsc::op::operator==;
|
||||
using edsc::intrinsics::select;
|
||||
|
||||
auto nPar = parallelIvs.size();
|
||||
auto nRed = reductionIvs.size();
|
||||
auto nWin = windowIvs.size();
|
||||
SmallVector<Value *, 8> allIvs;
|
||||
allIvs.reserve(nPar + nRed + nWin);
|
||||
allIvs.assign(parallelIvs.begin(), parallelIvs.end());
|
||||
allIvs.append(reductionIvs.begin(), reductionIvs.end());
|
||||
allIvs.append(windowIvs.begin(), windowIvs.end());
|
||||
|
||||
// Default OpBuilder supports 0-D case (no loops).
|
||||
OpBuilder b(linalgOp.getOperation());
|
||||
auto nLoops = nPar + nRed + nWin;
|
||||
if (nLoops > 0) {
|
||||
auto innermostLoop = loop::getForInductionVarOwner(allIvs.back());
|
||||
// accounts for linalg.terminator in loop.
|
||||
b = innermostLoop.getBodyBuilder();
|
||||
}
|
||||
|
||||
auto loc = linalgOp.getLoc();
|
||||
ScopedContext scope(b, loc);
|
||||
auto *op = linalgOp.getOperation();
|
||||
if (auto copyOp = dyn_cast<CopyOp>(op)) {
|
||||
OperationFolder state;
|
||||
auto inputIvs = permuteIvs(parallelIvs, copyOp.inputPermutation(), state);
|
||||
auto outputIvs = permuteIvs(parallelIvs, copyOp.outputPermutation(), state);
|
||||
SmallVector<IndexHandle, 8> iivs(inputIvs.begin(), inputIvs.end());
|
||||
SmallVector<IndexHandle, 8> oivs(outputIvs.begin(), outputIvs.end());
|
||||
// clang-format off
|
||||
IndexedValue O(copyOp.getOutput(0)), I(copyOp.getInput(0));
|
||||
nLoops > 0 ?
|
||||
O(oivs) = I(iivs) :
|
||||
O() = I();
|
||||
// clang-format on
|
||||
return;
|
||||
}
|
||||
if (auto fillOp = dyn_cast<FillOp>(op)) {
|
||||
SmallVector<IndexHandle, 8> ivs(parallelIvs.begin(), parallelIvs.end());
|
||||
// clang-format off
|
||||
IndexedValue O(fillOp.getOutput(0));
|
||||
nLoops > 0 ?
|
||||
O(ivs) = ValueHandle(fillOp.getValue()) :
|
||||
O() = ValueHandle(fillOp.getValue());
|
||||
// clang-format on
|
||||
return;
|
||||
}
|
||||
if (auto dotOp = dyn_cast<DotOp>(op)) {
|
||||
IndexHandle r_i(reductionIvs[0]);
|
||||
IndexedValue A(dotOp.getInput(0)), B(dotOp.getInput(1)),
|
||||
C(dotOp.getOutput(0));
|
||||
C() = C() + A(r_i) * B(r_i);
|
||||
return;
|
||||
}
|
||||
if (auto matvecOp = dyn_cast<MatvecOp>(op)) {
|
||||
IndexHandle i(parallelIvs[0]), r_j(reductionIvs[0]);
|
||||
IndexedValue A(matvecOp.getInput(0)), B(matvecOp.getInput(1)),
|
||||
C(matvecOp.getOutput(0));
|
||||
C(i) = C(i) + A(i, r_j) * B(r_j);
|
||||
return;
|
||||
}
|
||||
if (auto matmulOp = dyn_cast<MatmulOp>(op)) {
|
||||
IndexHandle i(parallelIvs[0]), j(parallelIvs[1]), r_k(reductionIvs[0]);
|
||||
IndexedValue A(matmulOp.getInput(0)), B(matmulOp.getInput(1)),
|
||||
C(matmulOp.getOutput(0));
|
||||
C(i, j) = C(i, j) + A(i, r_k) * B(r_k, j);
|
||||
return;
|
||||
}
|
||||
if (auto convOp = dyn_cast<ConvOp>(op)) {
|
||||
auto maps = loopToOperandRangesMaps(op);
|
||||
SmallVector<ValueHandle, 8> fIdx(
|
||||
foldedAffineApplies(b, loc, maps[0], allIvs, folder));
|
||||
SmallVector<ValueHandle, 8> imIdx(
|
||||
foldedAffineApplies(b, loc, maps[1], allIvs, folder));
|
||||
SmallVector<ValueHandle, 8> oIdx(
|
||||
foldedAffineApplies(b, loc, maps[2], allIvs, folder));
|
||||
IndexedValue F(convOp.filter()), I(convOp.input()), O(convOp.output());
|
||||
O(oIdx) += F(fIdx) * I(imIdx);
|
||||
return;
|
||||
}
|
||||
if (auto genericOp = dyn_cast<GenericOp>(op)) {
|
||||
using edsc::intrinsics::detail::ValueHandleArray;
|
||||
unsigned nInputs = genericOp.getNumInputs();
|
||||
unsigned nOutputs = genericOp.getNumOutputs();
|
||||
SmallVector<Value *, 4> indexedValues(nInputs + nOutputs);
|
||||
// Emits the MLIR for the scalar part of the generic op by:
|
||||
// 1. Emitting linalg_load and linalg_store ops for each input and output
|
||||
// view in order. This is achieved by applying the appropriate input or
|
||||
// output map to the enclosing induction variables.
|
||||
// 2. Emitting a call to `op.fun()` that takes as arguments the scalars
|
||||
// from point 1. above.
|
||||
// 3. Emitting linalg_store to store the results of 2. to the output
|
||||
// views.
|
||||
//
|
||||
// An example output may resemble:
|
||||
//
|
||||
// ```
|
||||
// loop.for %i = %c0 to %0 step %c1 {
|
||||
// loop.for %j = %c0 to %1 step %c1 {
|
||||
// loop.for %k = %c0 to %4 step %c1 {
|
||||
// %11 = linalg.load %arg0[%i, %j] : !linalg.view<?x?xf32>
|
||||
// %12 = linalg.load %arg1[%i, %j, %k] : !linalg.view<?x?x?xf32>
|
||||
// %13 = linalg.load %arg2[%i, %k, %j] : !linalg.view<?x?x?xf32>
|
||||
// %14:2 = call @foo(%11, %12, %13) : (f32, f32, f32) -> (f32, f32)
|
||||
// linalg.store %14#0, %arg1[%i, %j, %k] : !linalg.view<?x?x?xf32>
|
||||
// linalg.store %14#1, %arg2[%i, %k, %j] : !linalg.view<?x?x?xf32>
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// ```
|
||||
|
||||
// 1.a. Emit linalg_load from input views.
|
||||
for (unsigned i = 0, e = nInputs; i < e; ++i) {
|
||||
ValueHandleArray indexing(foldedAffineApplies(
|
||||
b, loc, genericOp.getInputIndexingMap(i), allIvs, folder));
|
||||
indexedValues[i] = linalg_load(genericOp.getInput(i), indexing);
|
||||
}
|
||||
// 1.b. Emit linalg_load from output views..
|
||||
for (unsigned i = 0, e = nOutputs; i < e; ++i) {
|
||||
ValueHandleArray indexing(foldedAffineApplies(
|
||||
b, loc, genericOp.getOutputIndexingMap(i), allIvs, folder));
|
||||
indexedValues[nInputs + i] =
|
||||
linalg_load(genericOp.getOutput(i), indexing);
|
||||
}
|
||||
// 2. Emit call.
|
||||
auto m = genericOp.getParentOfType<ModuleOp>();
|
||||
auto fun = m.lookupSymbol<FuncOp>(genericOp.fun());
|
||||
Operation *callOp = call(fun, indexedValues);
|
||||
assert(callOp->getNumResults() == genericOp.getNumOutputs());
|
||||
// 3. Emit linalg_store.
|
||||
for (unsigned i = 0, e = nOutputs; i < e; ++i) {
|
||||
ValueHandleArray indexing(foldedAffineApplies(
|
||||
b, loc, genericOp.getOutputIndexingMap(i), allIvs, folder));
|
||||
linalg_store(callOp->getResult(i), genericOp.getOutput(i), indexing);
|
||||
}
|
||||
return;
|
||||
}
|
||||
llvm_unreachable("Missing emitScalarImplementation for op");
|
||||
}
|
||||
|
@ -15,6 +15,8 @@
|
||||
// limitations under the License.
|
||||
// =============================================================================
|
||||
|
||||
#include "mlir/AffineOps/AffineOps.h"
|
||||
#include "mlir/Dialect/LoopOps/LoopOps.h"
|
||||
#include "mlir/EDSC/Helpers.h"
|
||||
#include "mlir/IR/AffineExpr.h"
|
||||
#include "mlir/IR/AffineMap.h"
|
||||
@ -22,17 +24,50 @@
|
||||
#include "mlir/Linalg/IR/LinalgOps.h"
|
||||
#include "mlir/Linalg/IR/LinalgTypes.h"
|
||||
#include "mlir/Linalg/Passes.h"
|
||||
#include "mlir/Linalg/Utils/Intrinsics.h"
|
||||
#include "mlir/Linalg/Utils/Utils.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/StandardOps/Ops.h"
|
||||
#include "mlir/Support/LLVM.h"
|
||||
#include "mlir/Support/STLExtras.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
#include "mlir/Transforms/FoldUtils.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::edsc;
|
||||
using namespace mlir::edsc::intrinsics;
|
||||
using namespace mlir::linalg;
|
||||
using namespace mlir::linalg::intrinsics;
|
||||
|
||||
using IndexedLinalgValue = TemplatedIndexedValue<linalg_load, linalg_store>;
|
||||
using edsc::op::operator+;
|
||||
using edsc::op::operator==;
|
||||
|
||||
static SmallVector<ValueHandle, 8>
|
||||
foldedAffineApplies(OpBuilder &b, Location loc, AffineMap map,
|
||||
ArrayRef<Value *> vals, OperationFolder &folder) {
|
||||
assert(map.getNumSymbols() == 0);
|
||||
assert(map.getNumInputs() == vals.size());
|
||||
SmallVector<ValueHandle, 8> res;
|
||||
res.reserve(map.getNumResults());
|
||||
auto dims = map.getNumDims();
|
||||
for (auto e : map.getResults()) {
|
||||
auto exprMap = AffineMap::get(dims, 0, e);
|
||||
SmallVector<Value *, 4> operands(vals.begin(), vals.end());
|
||||
canonicalizeMapAndOperands(&exprMap, &operands);
|
||||
res.push_back(affine_apply(folder, exprMap, operands));
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
static SmallVector<Value *, 4> permuteIvs(ArrayRef<Value *> ivs,
|
||||
Optional<AffineMap> permutation,
|
||||
OperationFolder &state) {
|
||||
return permutation ? applyMapToValues(ScopedContext::getBuilder(),
|
||||
ScopedContext::getLocation(),
|
||||
permutation.getValue(), ivs, state)
|
||||
: SmallVector<Value *, 4>(ivs.begin(), ivs.end());
|
||||
}
|
||||
|
||||
// Creates a number of ranges equal to the number of results in `map`.
|
||||
// The returned ranges correspond to the loop ranges, in the proper order, for
|
||||
@ -40,61 +75,272 @@ using namespace mlir::linalg;
|
||||
static SmallVector<Value *, 4> emitLoopRanges(OpBuilder &b, Location loc,
|
||||
AffineMap map,
|
||||
ArrayRef<Value *> allViewSizes,
|
||||
OperationFolder &state) {
|
||||
OperationFolder &folder) {
|
||||
// Apply `map` to get view sizes in loop order.
|
||||
auto sizes = applyMapToValues(b, loc, map, allViewSizes, state);
|
||||
auto sizes = applyMapToValues(b, loc, map, allViewSizes, folder);
|
||||
// Create a new range with the applied tile sizes.
|
||||
ScopedContext scope(b, loc);
|
||||
SmallVector<Value *, 4> res;
|
||||
for (unsigned idx = 0, e = map.getNumResults(); idx < e; ++idx) {
|
||||
res.push_back(b.create<RangeOp>(
|
||||
loc, state.create<ConstantIndexOp>(b, loc, 0), sizes[idx],
|
||||
state.create<ConstantIndexOp>(b, loc, 1)));
|
||||
res.push_back(range(constant_index(folder, 0), sizes[idx],
|
||||
constant_index(folder, 1)));
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
static void emitLinalgOpAsLoops(LinalgOp &linalgOp, OperationFolder &state) {
|
||||
OpBuilder b(linalgOp.getOperation());
|
||||
ScopedContext scope(b, linalgOp.getOperation()->getLoc());
|
||||
// The flattened loopToOperandRangesMaps is expected to be an invertible
|
||||
// permutation map (which is asserted in the inverse calculation).
|
||||
auto invertedMap =
|
||||
inversePermutation(concatAffineMaps(loopToOperandRangesMaps(linalgOp)));
|
||||
if (!invertedMap) {
|
||||
mlir::linalg::emitScalarImplementation({}, {}, {}, linalgOp, state);
|
||||
return;
|
||||
template <typename LinalgOpType> class LinalgScopedEmitter {};
|
||||
|
||||
template <> class LinalgScopedEmitter<CopyOp> {
|
||||
public:
|
||||
static void emitScalarImplementation(ArrayRef<Value *> allIvs, CopyOp copyOp,
|
||||
OperationFolder &folder) {
|
||||
auto nPar = copyOp.getNumParallelLoops();
|
||||
assert(nPar == allIvs.size());
|
||||
auto inputIvs =
|
||||
permuteIvs(allIvs.take_front(nPar), copyOp.inputPermutation(), folder);
|
||||
auto outputIvs =
|
||||
permuteIvs(allIvs.take_front(nPar), copyOp.outputPermutation(), folder);
|
||||
SmallVector<IndexHandle, 8> iivs(inputIvs.begin(), inputIvs.end());
|
||||
SmallVector<IndexHandle, 8> oivs(outputIvs.begin(), outputIvs.end());
|
||||
IndexedLinalgValue O(copyOp.getOutput(0)), I(copyOp.getInput(0));
|
||||
// Emit the proper scalar assignment, whether we are dealing with a 0-D or
|
||||
// an n-D loop nest; with or without permutations.
|
||||
// clang-format off
|
||||
nPar > 0 ? O(oivs) = I(iivs) :
|
||||
O() = I();
|
||||
// clang-format on
|
||||
}
|
||||
};
|
||||
|
||||
template <> class LinalgScopedEmitter<FillOp> {
|
||||
public:
|
||||
static void emitScalarImplementation(ArrayRef<Value *> allIvs, FillOp fillOp,
|
||||
OperationFolder &folder) {
|
||||
auto nPar = fillOp.getNumParallelLoops();
|
||||
assert(nPar == allIvs.size());
|
||||
auto ivs =
|
||||
SmallVector<IndexHandle, 4>(allIvs.begin(), allIvs.begin() + nPar);
|
||||
IndexedLinalgValue O(fillOp.getOutput(0));
|
||||
// Emit the proper scalar assignment, whether we are dealing with a 0-D or
|
||||
// an n-D loop nest; with or without permutations.
|
||||
nPar > 0 ? O(ivs) = ValueHandle(fillOp.getValue())
|
||||
: O() = ValueHandle(fillOp.getValue());
|
||||
}
|
||||
};
|
||||
|
||||
template <> class LinalgScopedEmitter<DotOp> {
|
||||
public:
|
||||
static void emitScalarImplementation(ArrayRef<Value *> allIvs, DotOp dotOp,
|
||||
OperationFolder &folder) {
|
||||
assert(allIvs.size() == 1);
|
||||
IndexHandle r_i(allIvs[0]);
|
||||
IndexedLinalgValue A(dotOp.getInput(0)), B(dotOp.getInput(1)),
|
||||
C(dotOp.getOutput(0));
|
||||
// Emit scalar form.
|
||||
C() = C() + A(r_i) * B(r_i);
|
||||
}
|
||||
};
|
||||
|
||||
template <> class LinalgScopedEmitter<MatvecOp> {
|
||||
public:
|
||||
static void emitScalarImplementation(ArrayRef<Value *> allIvs,
|
||||
MatvecOp matvecOp,
|
||||
OperationFolder &folder) {
|
||||
assert(allIvs.size() == 2);
|
||||
IndexHandle i(allIvs[0]), r_j(allIvs[1]);
|
||||
IndexedLinalgValue A(matvecOp.getInput(0)), B(matvecOp.getInput(1)),
|
||||
C(matvecOp.getOutput(0));
|
||||
// Emit scalar form.
|
||||
C(i) = C(i) + A(i, r_j) * B(r_j);
|
||||
}
|
||||
};
|
||||
|
||||
template <> class LinalgScopedEmitter<MatmulOp> {
|
||||
public:
|
||||
static void emitScalarImplementation(ArrayRef<Value *> allIvs,
|
||||
MatmulOp matmulOp,
|
||||
OperationFolder &folder) {
|
||||
assert(allIvs.size() == 3);
|
||||
IndexHandle i(allIvs[0]), j(allIvs[1]), r_k(allIvs[2]);
|
||||
IndexedLinalgValue A(matmulOp.getInput(0)), B(matmulOp.getInput(1)),
|
||||
C(matmulOp.getOutput(0));
|
||||
// Emit scalar form.
|
||||
C(i, j) = C(i, j) + A(i, r_k) * B(r_k, j);
|
||||
}
|
||||
};
|
||||
|
||||
template <> class LinalgScopedEmitter<ConvOp> {
|
||||
public:
|
||||
static void emitScalarImplementation(ArrayRef<Value *> allIvs, ConvOp convOp,
|
||||
OperationFolder &folder) {
|
||||
auto b = ScopedContext::getBuilder();
|
||||
auto loc = ScopedContext::getLocation();
|
||||
auto maps = loopToOperandRangesMaps(convOp);
|
||||
SmallVector<ValueHandle, 8> fIdx(
|
||||
foldedAffineApplies(b, loc, maps[0], allIvs, folder));
|
||||
SmallVector<ValueHandle, 8> imIdx(
|
||||
foldedAffineApplies(b, loc, maps[1], allIvs, folder));
|
||||
SmallVector<ValueHandle, 8> oIdx(
|
||||
foldedAffineApplies(b, loc, maps[2], allIvs, folder));
|
||||
IndexedLinalgValue F(convOp.filter()), I(convOp.input()),
|
||||
O(convOp.output());
|
||||
// Emit scalar form.
|
||||
O(oIdx) += F(fIdx) * I(imIdx);
|
||||
}
|
||||
};
|
||||
|
||||
// Emits the MLIR for the scalar part of the generic op by:
|
||||
// 1. Emitting linalg_load and linalg_store ops for each input and output
|
||||
// view in order. This is achieved by applying the appropriate input or
|
||||
// output map to the enclosing induction variables.
|
||||
// 2. Emitting a call to `op.fun()` that takes as arguments the scalars
|
||||
// from point 1. above.
|
||||
// 3. Emitting linalg_store to store the results of 2. to the output
|
||||
// views.
|
||||
//
|
||||
// An example output may resemble:
|
||||
//
|
||||
// ```
|
||||
// loop.for %i = %c0 to %0 step %c1 {
|
||||
// loop.for %j = %c0 to %1 step %c1 {
|
||||
// loop.for %k = %c0 to %4 step %c1 {
|
||||
// %11 = linalg.load %arg0[%i, %j] : !linalg.view<?x?xf32>
|
||||
// %12 = linalg.load %arg1[%i, %j, %k] : !linalg.view<?x?x?xf32>
|
||||
// %13 = linalg.load %arg2[%i, %k, %j] : !linalg.view<?x?x?xf32>
|
||||
// %14:2 = call @foo(%11, %12, %13) : (f32, f32, f32) -> (f32, f32)
|
||||
// linalg.store %14#0, %arg1[%i, %j, %k] : !linalg.view<?x?x?xf32>
|
||||
// linalg.store %14#1, %arg2[%i, %k, %j] : !linalg.view<?x?x?xf32>
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// ```
|
||||
template <> class LinalgScopedEmitter<GenericOp> {
|
||||
public:
|
||||
static void emitScalarImplementation(ArrayRef<Value *> allIvs,
|
||||
GenericOp genericOp,
|
||||
OperationFolder &folder) {
|
||||
auto b = ScopedContext::getBuilder();
|
||||
auto loc = ScopedContext::getLocation();
|
||||
using edsc::intrinsics::detail::ValueHandleArray;
|
||||
unsigned nInputs = genericOp.getNumInputs();
|
||||
unsigned nOutputs = genericOp.getNumOutputs();
|
||||
SmallVector<Value *, 4> indexedValues(nInputs + nOutputs);
|
||||
|
||||
// 1.a. Emit linalg_load from input views.
|
||||
for (unsigned i = 0, e = nInputs; i < e; ++i) {
|
||||
ValueHandleArray indexing(foldedAffineApplies(
|
||||
b, loc, genericOp.getInputIndexingMap(i), allIvs, folder));
|
||||
indexedValues[i] = linalg_load(genericOp.getInput(i), indexing);
|
||||
}
|
||||
|
||||
// 1.b. Emit linalg_load from output views.
|
||||
for (unsigned i = 0, e = nOutputs; i < e; ++i) {
|
||||
ValueHandleArray indexing(foldedAffineApplies(
|
||||
b, loc, genericOp.getOutputIndexingMap(i), allIvs, folder));
|
||||
indexedValues[nInputs + i] =
|
||||
linalg_load(genericOp.getOutput(i), indexing);
|
||||
}
|
||||
|
||||
// 2. Emit call.
|
||||
auto m = genericOp.getParentOfType<ModuleOp>();
|
||||
auto fun = m.lookupSymbol<FuncOp>(genericOp.fun());
|
||||
Operation *callOp = call(fun, indexedValues);
|
||||
assert(callOp->getNumResults() == genericOp.getNumOutputs());
|
||||
|
||||
// 3. Emit linalg_store.
|
||||
for (unsigned i = 0, e = nOutputs; i < e; ++i) {
|
||||
ValueHandleArray indexing(foldedAffineApplies(
|
||||
b, loc, genericOp.getOutputIndexingMap(i), allIvs, folder));
|
||||
linalg_store(callOp->getResult(i), genericOp.getOutput(i), indexing);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename ConcreteOp>
|
||||
class LinalgRewritePattern : public RewritePattern {
|
||||
public:
|
||||
explicit LinalgRewritePattern(MLIRContext *context)
|
||||
: RewritePattern(ConcreteOp::getOperationName(), /*benefit=*/1, context) {
|
||||
}
|
||||
|
||||
auto loopRanges = emitLoopRanges(scope.getBuilder(), scope.getLocation(),
|
||||
invertedMap, getViewSizes(linalgOp), state);
|
||||
PatternMatchResult matchAndRewrite(Operation *op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
OpBuilder b(op);
|
||||
ScopedContext scope(b, op->getLoc());
|
||||
|
||||
SmallVector<IndexHandle, 4> parallelIvs(linalgOp.getNumParallelLoops());
|
||||
SmallVector<IndexHandle, 4> reductionIvs(linalgOp.getNumReductionLoops());
|
||||
SmallVector<IndexHandle, 4> windowIvs(linalgOp.getNumWindowLoops());
|
||||
auto pivs = IndexHandle::makeIndexHandlePointers(parallelIvs);
|
||||
auto rivs = IndexHandle::makeIndexHandlePointers(reductionIvs);
|
||||
auto wivs = IndexHandle::makeIndexHandlePointers(windowIvs);
|
||||
assert(loopRanges.size() == pivs.size() + rivs.size() + wivs.size());
|
||||
// The flattened loopToOperandRangesMaps is expected to be an invertible
|
||||
// permutation map (which is asserted in the inverse calculation).
|
||||
auto linalgOp = cast<ConcreteOp>(op);
|
||||
auto invertedMap =
|
||||
inversePermutation(concatAffineMaps(loopToOperandRangesMaps(linalgOp)));
|
||||
if (!invertedMap) {
|
||||
LinalgScopedEmitter<ConcreteOp>::emitScalarImplementation({}, linalgOp,
|
||||
folder);
|
||||
rewriter.replaceOp(op, {});
|
||||
return matchSuccess();
|
||||
}
|
||||
|
||||
// clang-format off
|
||||
ArrayRef<Value *> ranges(loopRanges);
|
||||
LoopNestRangeBuilder(pivs, ranges.take_front(pivs.size()))([&] {
|
||||
LoopNestRangeBuilder(
|
||||
rivs, ranges.drop_back(wivs.size()).take_back(rivs.size()))([&] {
|
||||
LoopNestRangeBuilder(wivs, ranges.take_back(wivs.size()))(
|
||||
[&linalgOp, ¶llelIvs, &reductionIvs, &windowIvs, &state] {
|
||||
SmallVector<mlir::Value *, 4> parallel(
|
||||
parallelIvs.begin(), parallelIvs.end());
|
||||
SmallVector<mlir::Value *, 4> reduction(
|
||||
reductionIvs.begin(), reductionIvs.end());
|
||||
SmallVector<mlir::Value *, 4> window(
|
||||
windowIvs.begin(), windowIvs.end());
|
||||
mlir::linalg::emitScalarImplementation(
|
||||
parallel, reduction, window, linalgOp, state);
|
||||
auto nPar = linalgOp.getNumParallelLoops();
|
||||
auto nRed = linalgOp.getNumReductionLoops();
|
||||
auto nWin = linalgOp.getNumWindowLoops();
|
||||
SmallVector<IndexHandle, 4> allIvs(nPar + nRed + nWin);
|
||||
SmallVector<ValueHandle *, 4> allPIvs = makeIndexHandlePointers(allIvs);
|
||||
auto pivs = MutableArrayRef<ValueHandle *>(allPIvs).take_front(nPar);
|
||||
auto rivs = MutableArrayRef<ValueHandle *>(allPIvs)
|
||||
.take_front(nPar + nRed)
|
||||
.take_back(nRed);
|
||||
auto wivs = MutableArrayRef<ValueHandle *>(allPIvs).take_back(nWin);
|
||||
|
||||
auto loopRanges =
|
||||
emitLoopRanges(scope.getBuilder(), scope.getLocation(), invertedMap,
|
||||
getViewSizes(linalgOp), folder);
|
||||
assert(loopRanges.size() == pivs.size() + rivs.size() + wivs.size());
|
||||
|
||||
// clang-format off
|
||||
ArrayRef<Value *> ranges(loopRanges);
|
||||
LoopNestRangeBuilder(pivs, ranges.take_front(nPar))([&] {
|
||||
LoopNestRangeBuilder(rivs, ranges.drop_back(nWin).take_back(nRed))([&] {
|
||||
LoopNestRangeBuilder(wivs, ranges.take_back(wivs.size()))(
|
||||
[&linalgOp, &allIvs, this] {
|
||||
auto allIvValues = extractValues(allIvs);
|
||||
LinalgScopedEmitter<ConcreteOp>::emitScalarImplementation(
|
||||
allIvValues, linalgOp, folder);
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
// clang-format on
|
||||
// clang-format on
|
||||
rewriter.replaceOp(op, {});
|
||||
return matchSuccess();
|
||||
}
|
||||
|
||||
mutable OperationFolder folder;
|
||||
};
|
||||
|
||||
// Helper classes for type list expansion.
|
||||
template <typename... LinalgOps> class ConversionList;
|
||||
|
||||
template <> class ConversionList<> {
|
||||
public:
|
||||
static void build(OwningRewritePatternList &patterns, MLIRContext *ctx) {}
|
||||
};
|
||||
|
||||
template <typename ConcreteOp, typename... LinalgOps>
|
||||
class ConversionList<ConcreteOp, LinalgOps...> {
|
||||
public:
|
||||
static void build(OwningRewritePatternList &patterns, MLIRContext *ctx) {
|
||||
patterns.insert<LinalgRewritePattern<ConcreteOp>>(ctx);
|
||||
ConversionList<LinalgOps...>::build(patterns, ctx);
|
||||
}
|
||||
};
|
||||
|
||||
/// Populate the given list with patterns that convert from Linalg to LLVM.
|
||||
static void
|
||||
populateLinalgToLoopRewritePatterns(OwningRewritePatternList &patterns,
|
||||
MLIRContext *ctx) {
|
||||
ConversionList<
|
||||
#define GET_OP_LIST
|
||||
#include "mlir/Linalg/IR/LinalgLibraryOps.cpp.inc"
|
||||
>::build(patterns, ctx);
|
||||
}
|
||||
|
||||
namespace {
|
||||
@ -104,11 +350,17 @@ struct LowerLinalgToLoopsPass : public FunctionPass<LowerLinalgToLoopsPass> {
|
||||
} // namespace
|
||||
|
||||
void LowerLinalgToLoopsPass::runOnFunction() {
|
||||
OperationFolder state;
|
||||
getFunction().walk<LinalgOp>([&state](LinalgOp linalgOp) {
|
||||
emitLinalgOpAsLoops(linalgOp, state);
|
||||
linalgOp.getOperation()->erase();
|
||||
});
|
||||
OwningRewritePatternList patterns;
|
||||
populateLinalgToLoopRewritePatterns(patterns, &getContext());
|
||||
|
||||
ConversionTarget target(getContext());
|
||||
target.addLegalDialect<AffineOpsDialect>();
|
||||
target.addLegalDialect<loop::LoopOpsDialect>();
|
||||
target.addLegalDialect<StandardOpsDialect>();
|
||||
if (failed(
|
||||
applyPartialConversion(getFunction(), target, std::move(patterns)))) {
|
||||
signalPassFailure();
|
||||
}
|
||||
}
|
||||
|
||||
FunctionPassBase *mlir::linalg::createLowerLinalgToLoopsPass() {
|
||||
|
@ -381,7 +381,7 @@ mlir::linalg::tileLinalgOp(LinalgOp op, ArrayRef<Value *> tileSizes,
|
||||
// 3. Create the tiled loops.
|
||||
LinalgOp res = op;
|
||||
SmallVector<IndexHandle, 4> ivs(loopRanges.size());
|
||||
auto pivs = IndexHandle::makeIndexHandlePointers(ivs);
|
||||
auto pivs = makeIndexHandlePointers(ivs);
|
||||
LoopNestRangeBuilder(pivs, loopRanges)([&] {
|
||||
auto b = ScopedContext::getBuilder();
|
||||
auto loc = ScopedContext::getLocation();
|
||||
|
10
third_party/mlir/lib/Linalg/Utils/Utils.cpp
vendored
10
third_party/mlir/lib/Linalg/Utils/Utils.cpp
vendored
@ -106,16 +106,6 @@ ValueHandle LoopNestRangeBuilder::LoopNestRangeBuilder::operator()(
|
||||
return ValueHandle::null();
|
||||
}
|
||||
|
||||
SmallVector<Value *, 8> mlir::linalg::getViewSizes(LinalgOp &linalgOp) {
|
||||
SmallVector<Value *, 8> res;
|
||||
for (auto v : linalgOp.getInputsAndOutputs()) {
|
||||
ViewType t = v->getType().cast<ViewType>();
|
||||
for (unsigned i = 0; i < t.getRank(); ++i)
|
||||
res.push_back(linalg::intrinsics::dim(v, i));
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
static Value *emitOrFoldComposedAffineApply(OpBuilder &b, Location loc,
|
||||
AffineMap map,
|
||||
ArrayRef<Value *> operandsRef,
|
||||
|
@ -273,10 +273,9 @@ VectorTransferRewriter<VectorTransferReadOp>::matchAndRewrite(
|
||||
IndexedValue remote(transfer.getMemRef());
|
||||
MemRefView view(transfer.getMemRef());
|
||||
VectorView vectorView(transfer.getVector());
|
||||
SmallVector<IndexHandle, 8> ivs =
|
||||
IndexHandle::makeIndexHandles(vectorView.rank());
|
||||
SmallVector<IndexHandle, 8> ivs = makeIndexHandles(vectorView.rank());
|
||||
SmallVector<ValueHandle *, 8> pivs =
|
||||
IndexHandle::makeIndexHandlePointers(ivs);
|
||||
makeIndexHandlePointers(MutableArrayRef<IndexHandle>(ivs));
|
||||
coalesceCopy(transfer, &pivs, &vectorView);
|
||||
|
||||
auto lbs = vectorView.getLbs();
|
||||
@ -335,10 +334,8 @@ VectorTransferRewriter<VectorTransferWriteOp>::matchAndRewrite(
|
||||
MemRefView view(transfer.getMemRef());
|
||||
ValueHandle vectorValue(transfer.getVector());
|
||||
VectorView vectorView(transfer.getVector());
|
||||
SmallVector<IndexHandle, 8> ivs =
|
||||
IndexHandle::makeIndexHandles(vectorView.rank());
|
||||
SmallVector<ValueHandle *, 8> pivs =
|
||||
IndexHandle::makeIndexHandlePointers(ivs);
|
||||
SmallVector<IndexHandle, 8> ivs = makeIndexHandles(vectorView.rank());
|
||||
SmallVector<ValueHandle *, 8> pivs = makeIndexHandlePointers(ivs);
|
||||
coalesceCopy(transfer, &pivs, &vectorView);
|
||||
|
||||
auto lbs = vectorView.getLbs();
|
||||
|
Loading…
x
Reference in New Issue
Block a user