Adds support for unrolling single-result vector operations with iterator type lists and indexing maps to a target vector size.
Adds unit tests for unrolling the vector ContractionOp with different iteration orders. PiperOrigin-RevId: 283747503 Change-Id: Ib7e4f757d15760cd89fc09fedc49b7a2dc2a1fe6
This commit is contained in:
parent
222977dffd
commit
8f661bace2
third_party/mlir
include/mlir/Dialect/VectorOps
lib/Dialect/VectorOps
@ -157,6 +157,18 @@ def Vector_ContractionOp :
|
||||
static StringRef getParallelIteratorTypeName() {
|
||||
return "parallel";
|
||||
}
|
||||
|
||||
// Returns the bounds of each dimension in the iteration space spanned
|
||||
// by the iterator types of this operation.
|
||||
void getIterationBounds(SmallVectorImpl<int64_t> &iterationBounds);
|
||||
|
||||
// Returns a list of index maps, where there is a list entry for each
|
||||
// op indexing map attribute (i.e. one for each input and output, with
|
||||
// the output listed last). Each index map, maps from this operations
|
||||
// iteration space, to vector dimensions of the maps input/output.
|
||||
void getIterationIndexMap(
|
||||
std::vector<DenseMap<int64_t, int64_t>> &iterationIndexMap);
|
||||
|
||||
std::vector<std::pair<int64_t, int64_t>> getContractingDimMap();
|
||||
std::vector<std::pair<int64_t, int64_t>> getBatchDimMap();
|
||||
}];
|
||||
|
@ -40,4 +40,9 @@ def : Pat<(AddFOp:$op_results $a, $b),
|
||||
(UnrollVectorOp<[2, 2]> $op_results, $a, $b),
|
||||
[(Constraint<HasShape<[4, 4]>> $a)]>;
|
||||
|
||||
// TODO(andydavis) Add Constraints on lhs/rhs shapes.
|
||||
def : Pat<(Vector_ContractionOp:$op_results $a, $b, $c, $masks, $attr0, $attr1),
|
||||
(UnrollVectorOp<[2, 2, 2]> $op_results, $a, $b, $c),
|
||||
[(Constraint<HasShape<[4, 4]>> $c)]>;
|
||||
|
||||
#endif // VECTOR_TRANSFORMS
|
||||
|
@ -271,6 +271,44 @@ getDimMap(ArrayRef<AffineMap> indexingMaps, ArrayAttr iteratorTypes,
|
||||
return dimMap;
|
||||
}
|
||||
|
||||
void ContractionOp::getIterationBounds(
|
||||
SmallVectorImpl<int64_t> &iterationBounds) {
|
||||
auto lhsShape = getLhsType().getShape();
|
||||
auto resShape = getResultType().getShape();
|
||||
SmallVector<AffineMap, 4> indexingMaps(getIndexingMaps());
|
||||
SmallVector<int64_t, 2> iterationShape;
|
||||
for (auto it : llvm::enumerate(iterator_types())) {
|
||||
// Search lhs/rhs map results for 'targetExpr'.
|
||||
auto targetExpr = getAffineDimExpr(it.index(), getContext());
|
||||
auto iteratorTypeName = it.value().cast<StringAttr>().getValue();
|
||||
if (iteratorTypeName == getReductionIteratorTypeName()) {
|
||||
// Get reduction dim size from lhs shape (same size in rhsShape).
|
||||
int64_t lhsDimIndex = getResultIndex(indexingMaps[0], targetExpr);
|
||||
assert(lhsDimIndex >= 0);
|
||||
iterationBounds.push_back(lhsShape[lhsDimIndex]);
|
||||
continue;
|
||||
}
|
||||
// Get parallel dimension size from result shape.
|
||||
int64_t resDimIndex = getResultIndex(indexingMaps[2], targetExpr);
|
||||
assert(resDimIndex >= 0);
|
||||
iterationBounds.push_back(resShape[resDimIndex]);
|
||||
}
|
||||
}
|
||||
|
||||
void ContractionOp::getIterationIndexMap(
|
||||
std::vector<DenseMap<int64_t, int64_t>> &iterationIndexMap) {
|
||||
unsigned numMaps = indexing_maps().getValue().size();
|
||||
iterationIndexMap.resize(numMaps);
|
||||
for (auto it : llvm::enumerate(indexing_maps())) {
|
||||
auto index = it.index();
|
||||
auto map = it.value().cast<AffineMapAttr>().getValue();
|
||||
for (unsigned i = 0, e = map.getNumResults(); i < e; ++i) {
|
||||
auto dim = map.getResult(i).cast<AffineDimExpr>();
|
||||
iterationIndexMap[index][dim.getPosition()] = i;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<std::pair<int64_t, int64_t>> ContractionOp::getContractingDimMap() {
|
||||
SmallVector<AffineMap, 4> indexingMaps(getIndexingMaps());
|
||||
return getDimMap(indexingMaps, iterator_types(),
|
||||
|
@ -77,6 +77,15 @@ static int64_t computeMaxLinearIndex(ArrayRef<int64_t> basis) {
|
||||
return res;
|
||||
}
|
||||
|
||||
/// Computes and returns the linearized index of 'offsets' w.r.t. 'basis'.
|
||||
static int64_t linearize(ArrayRef<int64_t> offsets, ArrayRef<int64_t> basis) {
|
||||
assert(offsets.size() == basis.size());
|
||||
int64_t linearIndex = 0;
|
||||
for (unsigned idx = 0, e = basis.size(); idx < e; ++idx)
|
||||
linearIndex += offsets[idx] * basis[idx];
|
||||
return linearIndex;
|
||||
}
|
||||
|
||||
/// Given a shape with sizes greater than 0 along all dimensions, returns the
|
||||
/// delinearized components of linearIndex along shape.
|
||||
static SmallVector<int64_t, 8> delinearize(int64_t linearIndex,
|
||||
@ -151,9 +160,9 @@ static Operation *cloneOpWithOperandsAndTypes(PatternRewriter &builder,
|
||||
Location loc, Operation *op,
|
||||
ArrayRef<Value *> operands,
|
||||
ArrayRef<Type> resultTypes) {
|
||||
OperationState *res = new OperationState(loc, op->getName().getStringRef(),
|
||||
operands, resultTypes, {});
|
||||
return builder.createOperation(*res);
|
||||
OperationState res(loc, op->getName().getStringRef(), operands, resultTypes,
|
||||
op->getAttrs());
|
||||
return builder.createOperation(res);
|
||||
}
|
||||
|
||||
// Helper function for Tablegen.
|
||||
@ -164,6 +173,223 @@ static bool hasShape(Value *v, ArrayRef<int64_t> shape) {
|
||||
return std::equal(t.getShape().begin(), t.getShape().end(), shape.begin());
|
||||
}
|
||||
|
||||
static Value *makeSplatZero(Location loc, PatternRewriter &rewriter,
|
||||
VectorType vt) {
|
||||
auto t = vt.getElementType();
|
||||
Value *f = nullptr;
|
||||
if (t.isBF16() || t.isF16())
|
||||
f = rewriter.create<ConstantOp>(loc, t, rewriter.getF64FloatAttr(0.0f));
|
||||
else if (t.isF32())
|
||||
f = rewriter.create<ConstantOp>(loc, t, rewriter.getF32FloatAttr(0.0f));
|
||||
else if (t.isF64())
|
||||
f = rewriter.create<ConstantOp>(loc, t, rewriter.getF64FloatAttr(0.0f));
|
||||
if (f)
|
||||
return rewriter.create<SplatOp>(loc, vt, f);
|
||||
llvm_unreachable("Unsupported type in `makeSplatZero`");
|
||||
}
|
||||
|
||||
// Populates 'resultElements[indexMap[i]]' with elements from 'inputElements[i]'
|
||||
// for each index 'i' in inputElements with a valid mapping in 'indexMap'.
|
||||
static void getMappedElements(const DenseMap<int64_t, int64_t> &indexMap,
|
||||
ArrayRef<int64_t> inputElements,
|
||||
SmallVectorImpl<int64_t> &resultElements) {
|
||||
assert(indexMap.size() == resultElements.size());
|
||||
assert(inputElements.size() >= resultElements.size());
|
||||
for (unsigned i = 0, e = inputElements.size(); i < e; ++i) {
|
||||
auto it = indexMap.find(i);
|
||||
if (it != indexMap.end())
|
||||
resultElements[it->second] = inputElements[i];
|
||||
}
|
||||
}
|
||||
|
||||
// UnrolledOperandState aggregates per-operand state required for op unrolling.
|
||||
struct UnrolledOperandState {
|
||||
Value *operand;
|
||||
SmallVector<int64_t, 4> unrolledShape;
|
||||
SmallVector<int64_t, 4> unrollFactors;
|
||||
SmallVector<int64_t, 8> basis;
|
||||
int64_t numInstances;
|
||||
};
|
||||
|
||||
// Populates 'state' with unrolled shape, unroll factors, basis and
|
||||
// num unrolled instances for 'operand'.
|
||||
static void getUnrolledOperandState(Value *operand,
|
||||
const DenseMap<int64_t, int64_t> &indexMap,
|
||||
ArrayRef<int64_t> targetShape,
|
||||
UnrolledOperandState &state) {
|
||||
auto vectorType = operand->getType().cast<VectorType>();
|
||||
state.operand = operand;
|
||||
// Compute unrolled shape of 'operand'.
|
||||
state.unrolledShape.resize(vectorType.getRank());
|
||||
getMappedElements(indexMap, targetShape, state.unrolledShape);
|
||||
// Compute unroll factors for unrolled shape.
|
||||
auto maybeUnrollFactors =
|
||||
shapeRatio(vectorType.getShape(), state.unrolledShape);
|
||||
assert(maybeUnrollFactors.hasValue());
|
||||
state.unrollFactors = *maybeUnrollFactors;
|
||||
// Compute 'basis' and 'numInstances' based on 'state.unrollFactors'.
|
||||
state.basis = computeStrides(state.unrollFactors);
|
||||
state.numInstances = computeMaxLinearIndex(state.unrollFactors);
|
||||
}
|
||||
|
||||
// Computes and returns the linear index of the unrolled vector at
|
||||
// 'vectorOffsets' within the vector operand represented by 'state'.
|
||||
static int64_t
|
||||
getUnrolledOperandLinearIndex(UnrolledOperandState &state,
|
||||
ArrayRef<int64_t> vectorOffsets,
|
||||
DenseMap<int64_t, int64_t> &indexMap) {
|
||||
// Compute operand offsets.
|
||||
SmallVector<int64_t, 4> sliceOffsets(state.unrolledShape.size());
|
||||
getMappedElements(indexMap, vectorOffsets, sliceOffsets);
|
||||
// Compute and return linear index of 'sliceOffsets' w.r.t 'state.basis'.
|
||||
return linearize(sliceOffsets, state.basis);
|
||||
}
|
||||
|
||||
// Returns an unrolled vector at 'vectorOffsets' within the vector operand
|
||||
// represented by 'state'. The value is created if not present in 'cache'.
|
||||
static Value *getOrCreateUnrolledOperandSlice(
|
||||
Location loc, UnrolledOperandState &state, ArrayRef<int64_t> vectorOffsets,
|
||||
ArrayRef<int64_t> offsets, DenseMap<int64_t, int64_t> &indexMap,
|
||||
SmallVectorImpl<Value *> &cache, PatternRewriter &builder) {
|
||||
// Compute operand offsets.
|
||||
SmallVector<int64_t, 4> sliceOffsets(state.unrolledShape.size());
|
||||
getMappedElements(indexMap, offsets, sliceOffsets);
|
||||
// TODO(b/144845578) Support non-1 strides.
|
||||
SmallVector<int64_t, 4> sliceStrides(state.unrolledShape.size(), 1);
|
||||
// Compute linear index of 'sliceOffsets' w.r.t 'state.basis'.
|
||||
int64_t sliceLinearIndex =
|
||||
getUnrolledOperandLinearIndex(state, vectorOffsets, indexMap);
|
||||
assert(sliceLinearIndex < static_cast<int64_t>(cache.size()));
|
||||
auto *operandSlice = cache[sliceLinearIndex];
|
||||
if (operandSlice == nullptr) {
|
||||
// Initialize 'cache' with slice from 'state.operand'.
|
||||
operandSlice = builder.create<vector::StridedSliceOp>(
|
||||
loc, state.operand, sliceOffsets, state.unrolledShape, sliceStrides);
|
||||
// Store value back to 'cache'.
|
||||
cache[sliceLinearIndex] = operandSlice;
|
||||
}
|
||||
return operandSlice;
|
||||
}
|
||||
|
||||
//
|
||||
// unrollSingleResultStructuredOp
|
||||
//
|
||||
// Returns a value representing the result of structured operation 'op'
|
||||
// with iteration bounds 'iterationBounds' unrolled to 'targetShape'.
|
||||
// An iteration space index map argument 'iterationIndexMapList' must be
|
||||
// specified, with a map for each structured op input and a single map for the
|
||||
// single result. The last map in the list must be the single result map.
|
||||
// Extra operands can be passed to unrolled instances of 'op' using the
|
||||
// 'extraOperands' argument.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// // Before unrolling
|
||||
//
|
||||
// operand0 operand1 operand2
|
||||
// \ | /
|
||||
// -------------------- opA --------------------
|
||||
//
|
||||
// // After unrolling by 2
|
||||
//
|
||||
// operand0 operand1 operand2
|
||||
// / \ / \ / \
|
||||
// slice00 slice01 slice10 slice11 slice20 slice21
|
||||
// \ | | | / |
|
||||
// -------------------- opA0 -------------------- |
|
||||
// | | | |
|
||||
// \ | | /
|
||||
// -------------------- opA1 -------------------
|
||||
// | |
|
||||
// \ /
|
||||
// insertslice
|
||||
// |
|
||||
|
||||
// TODO(andydavis) Generalize this to support structured ops beyond
|
||||
// vector ContractionOp, and merge it with 'unrollSingleResultOpMatchingType'
|
||||
static Value *unrollSingleResultStructuredOp(
|
||||
Operation *op, ArrayRef<int64_t> iterationBounds,
|
||||
std::vector<DenseMap<int64_t, int64_t>> &iterationIndexMapList,
|
||||
ArrayRef<int64_t> targetShape, ArrayRef<Value *> extraOperands,
|
||||
PatternRewriter &builder) {
|
||||
auto shapedType = op->getResult(0)->getType().dyn_cast_or_null<ShapedType>();
|
||||
if (!shapedType || !shapedType.hasStaticShape())
|
||||
assert(false && "Expected a statically shaped result type");
|
||||
|
||||
// Compute unroll factors for 'iterationBounds' based on 'targetShape'
|
||||
auto maybeUnrollFactors = shapeRatio(iterationBounds, targetShape);
|
||||
if (!maybeUnrollFactors.hasValue())
|
||||
assert(false && "Failed to compute unroll factors for target shape");
|
||||
auto unrollFactors = *maybeUnrollFactors;
|
||||
|
||||
// Compute unrolled operation state for each mapped operand.
|
||||
unsigned numMaps = iterationIndexMapList.size();
|
||||
SmallVector<UnrolledOperandState, 3> unrolledOperandState(numMaps);
|
||||
assert(op->getNumOperands() >= numMaps);
|
||||
for (unsigned i = 0; i < numMaps; ++i) {
|
||||
getUnrolledOperandState(op->getOperand(i), iterationIndexMapList[i],
|
||||
targetShape, unrolledOperandState[i]);
|
||||
}
|
||||
// Compute number of total unrolled instances.
|
||||
auto numUnrolledInstances = computeMaxLinearIndex(unrollFactors);
|
||||
auto basis = computeStrides(unrollFactors);
|
||||
|
||||
auto &resultOperandState = unrolledOperandState[numMaps - 1];
|
||||
auto unrolledResultType = VectorType::get(resultOperandState.unrolledShape,
|
||||
shapedType.getElementType());
|
||||
|
||||
// Initialize caches for intermediate vector results.
|
||||
std::vector<SmallVector<Value *, 4>> caches(numMaps);
|
||||
for (unsigned i = 0; i < numMaps; ++i) {
|
||||
caches[i].resize(unrolledOperandState[i].numInstances);
|
||||
}
|
||||
|
||||
// Unroll 'numUnrolledInstances' of 'op', storing results in 'caches'.
|
||||
for (unsigned i = 0; i < numUnrolledInstances; ++i) {
|
||||
// De-linearize w.r.t. 'basis'.
|
||||
auto vectorOffsets = delinearize(i, basis);
|
||||
// Convert from unrolled vector-space offsets to element-space offsets.
|
||||
auto offsets = zipMap([](int64_t v1, int64_t v2) { return v1 * v2; },
|
||||
vectorOffsets, targetShape);
|
||||
// Get cached slice (or create slice) for each operand at 'offsets'.
|
||||
SmallVector<Value *, 3> operands;
|
||||
operands.reserve(numMaps);
|
||||
for (unsigned i = 0; i < numMaps; ++i) {
|
||||
operands.push_back(getOrCreateUnrolledOperandSlice(
|
||||
op->getLoc(), unrolledOperandState[i], vectorOffsets, offsets,
|
||||
iterationIndexMapList[i], caches[i], builder));
|
||||
}
|
||||
// Create op on sliced vector arguments.
|
||||
operands.append(extraOperands.begin(), extraOperands.end());
|
||||
auto resultVector =
|
||||
cloneOpWithOperandsAndTypes(builder, op->getLoc(), op, operands,
|
||||
unrolledResultType)
|
||||
->getResult(0);
|
||||
|
||||
// Compute linear result index.
|
||||
int64_t resultIndex = getUnrolledOperandLinearIndex(
|
||||
resultOperandState, vectorOffsets, iterationIndexMapList[numMaps - 1]);
|
||||
// Update result cache at 'resultIndex'.
|
||||
caches[numMaps - 1][resultIndex] = resultVector;
|
||||
}
|
||||
|
||||
// Make zero splat into which we will insert results from 'cache[numMaps - 1]'
|
||||
auto resultVectorType = op->getResult(0)->getType().cast<VectorType>();
|
||||
auto *res = makeSplatZero(op->getLoc(), builder, resultVectorType);
|
||||
SmallVector<int64_t, 4> strides(resultOperandState.unrollFactors.size(), 1);
|
||||
// Insert vector accumulators into output.
|
||||
for (unsigned i = 0; i < resultOperandState.numInstances; ++i) {
|
||||
auto vectorOffsets = delinearize(i, resultOperandState.basis);
|
||||
// Convert from unrolled vector-space offsets to element-space offsets.
|
||||
auto offsets = zipMap([](int64_t v1, int64_t v2) { return v1 * v2; },
|
||||
vectorOffsets, resultOperandState.unrolledShape);
|
||||
res = builder.create<vector::InsertStridedSliceOp>(
|
||||
op->getLoc(), caches[numMaps - 1][i], res, offsets, strides);
|
||||
}
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
// Entry point for unrolling declarative pattern rewrites.
|
||||
// `op` is unrolled to the `targetShape` as follows, for each of its operands:
|
||||
// 1. the unrolled type `unrolledVectorType` and number of unrolled instances
|
||||
@ -200,6 +426,26 @@ static bool hasShape(Value *v, ArrayRef<int64_t> shape) {
|
||||
Value * mlir::vector::unrollSingleResultOpMatchingType(PatternRewriter &builder,
|
||||
Operation *op,
|
||||
ArrayRef<int64_t> targetShape) {
|
||||
if (auto contractionOp = dyn_cast<vector::ContractionOp>(op)) {
|
||||
// Get contraction op iteration bounds.
|
||||
SmallVector<int64_t, 6> iterationBounds;
|
||||
contractionOp.getIterationBounds(iterationBounds);
|
||||
assert(iterationBounds.size() == targetShape.size());
|
||||
// Get map from iteration space index to lhs/rhs/result shape index.
|
||||
std::vector<DenseMap<int64_t, int64_t>> iterationIndexMapList;
|
||||
contractionOp.getIterationIndexMap(iterationIndexMapList);
|
||||
// TODO(andydavis) Support unrollable vector masks.
|
||||
SmallVector<Value *, 2> masks(contractionOp.masks().begin(),
|
||||
contractionOp.masks().end());
|
||||
// Unroll 'op' 'iterationBounds' to 'targetShape'.
|
||||
return unrollSingleResultStructuredOp(op, iterationBounds,
|
||||
iterationIndexMapList, targetShape,
|
||||
masks, builder);
|
||||
}
|
||||
// TODO(andydavis) Create trivial iteration bounds and index map for
|
||||
// elementwise operations and call 'unrollSingleResultStructuredOp'. Remove
|
||||
// fakefork/join if possible.
|
||||
|
||||
LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE
|
||||
"]: unrollSingleResultOpMatchingType on func:\n");
|
||||
LLVM_DEBUG(op->getParentOfType<FuncOp>().print(dbgs()));
|
||||
@ -365,24 +611,6 @@ struct ConvertFakeForkFromBlockArgsOp : public RewritePattern {
|
||||
}
|
||||
};
|
||||
|
||||
static Value *makeSplatZero(Location loc, PatternRewriter &rewriter,
|
||||
VectorType vt) {
|
||||
auto t = vt.getElementType();
|
||||
Value *f = nullptr;
|
||||
if (t.isBF16() || t.isF16())
|
||||
f = rewriter.create<ConstantOp>(loc, t, rewriter.getF16FloatAttr(0.0f))
|
||||
.getResult();
|
||||
else if (t.isF32())
|
||||
f = rewriter.create<ConstantOp>(loc, t, rewriter.getF32FloatAttr(0.0f))
|
||||
.getResult();
|
||||
else if (t.isF64())
|
||||
f = rewriter.create<ConstantOp>(loc, t, rewriter.getF64FloatAttr(0.0f))
|
||||
.getResult();
|
||||
if (f)
|
||||
return rewriter.create<SplatOp>(loc, vt, f).getResult();
|
||||
llvm_unreachable("Unsupported type in `makeSplatZero`");
|
||||
}
|
||||
|
||||
// Rewrites a fakeJoin, whose (unique) operand is a blockArgument, into multiple
|
||||
// vector.strided_slice ops.
|
||||
struct ConvertFakeJoinOp : public RewritePattern {
|
||||
|
Loading…
Reference in New Issue
Block a user