Adds support for unrolling single-result vector operations with iterator type lists and indexing maps to a target vector size.

Adds unit tests for unrolling the vector ContractionOp with different iteration orders.

PiperOrigin-RevId: 283747503
Change-Id: Ib7e4f757d15760cd89fc09fedc49b7a2dc2a1fe6
This commit is contained in:
A. Unique TensorFlower 2019-12-04 06:53:07 -08:00 committed by TensorFlower Gardener
parent 222977dffd
commit 8f661bace2
4 changed files with 304 additions and 21 deletions
third_party/mlir
include/mlir/Dialect/VectorOps
lib/Dialect/VectorOps

View File

@ -157,6 +157,18 @@ def Vector_ContractionOp :
// Returns the canonical iterator-type string for parallel dimensions.
static StringRef getParallelIteratorTypeName() {
return "parallel";
}
// Returns the bounds of each dimension in the iteration space spanned
// by the iterator types of this operation.
void getIterationBounds(SmallVectorImpl<int64_t> &iterationBounds);
// Returns a list of index maps, where there is a list entry for each
// op indexing map attribute (i.e. one for each input and output, with
// the output listed last). Each index map maps from this operation's
// iteration space to the vector dimensions of the map's input/output.
void getIterationIndexMap(
std::vector<DenseMap<int64_t, int64_t>> &iterationIndexMap);
// Returns (dim, dim) pairs describing the contracting (reduction)
// dimensions of the lhs/rhs operands.
std::vector<std::pair<int64_t, int64_t>> getContractingDimMap();
// Returns (dim, dim) pairs describing the batch dimensions shared by the
// lhs/rhs operands.
std::vector<std::pair<int64_t, int64_t>> getBatchDimMap();
}];

View File

@ -40,4 +40,9 @@ def : Pat<(AddFOp:$op_results $a, $b),
(UnrollVectorOp<[2, 2]> $op_results, $a, $b),
[(Constraint<HasShape<[4, 4]>> $a)]>;
// TODO(andydavis) Add Constraints on lhs/rhs shapes.
// Unrolls a vector.contract whose accumulator/result has shape [4, 4] into
// instances of target vector size [2, 2, 2] — one entry per iteration-space
// dimension (two parallel dims plus one reduction dim).
def : Pat<(Vector_ContractionOp:$op_results $a, $b, $c, $masks, $attr0, $attr1),
(UnrollVectorOp<[2, 2, 2]> $op_results, $a, $b, $c),
[(Constraint<HasShape<[4, 4]>> $c)]>;
#endif // VECTOR_TRANSFORMS

View File

@ -271,6 +271,44 @@ getDimMap(ArrayRef<AffineMap> indexingMaps, ArrayAttr iteratorTypes,
return dimMap;
}
// Populates 'iterationBounds' with the trip count of every dimension in this
// contraction's iteration space. Reduction dims take their extent from the
// lhs shape (rhs has the same size by construction); parallel dims take
// theirs from the result shape.
// Fix: removed the unused local 'iterationShape' that was declared but never
// read or written.
void ContractionOp::getIterationBounds(
    SmallVectorImpl<int64_t> &iterationBounds) {
  auto lhsShape = getLhsType().getShape();
  auto resShape = getResultType().getShape();
  SmallVector<AffineMap, 4> indexingMaps(getIndexingMaps());
  for (auto it : llvm::enumerate(iterator_types())) {
    // Search lhs/rhs map results for 'targetExpr'.
    auto targetExpr = getAffineDimExpr(it.index(), getContext());
    auto iteratorTypeName = it.value().cast<StringAttr>().getValue();
    if (iteratorTypeName == getReductionIteratorTypeName()) {
      // Get reduction dim size from lhs shape (same size in rhsShape).
      int64_t lhsDimIndex = getResultIndex(indexingMaps[0], targetExpr);
      assert(lhsDimIndex >= 0);
      iterationBounds.push_back(lhsShape[lhsDimIndex]);
      continue;
    }
    // Get parallel dimension size from result shape.
    int64_t resDimIndex = getResultIndex(indexingMaps[2], targetExpr);
    assert(resDimIndex >= 0);
    iterationBounds.push_back(resShape[resDimIndex]);
  }
}
// For each indexing map of this op, builds a table mapping an iteration-space
// dimension number to the position of that dimension in the corresponding
// operand/result vector. One table is produced per indexing-map attribute,
// in attribute order (inputs first, result last).
void ContractionOp::getIterationIndexMap(
    std::vector<DenseMap<int64_t, int64_t>> &iterationIndexMap) {
  iterationIndexMap.resize(indexing_maps().getValue().size());
  unsigned mapIndex = 0;
  for (auto attr : indexing_maps()) {
    auto affineMap = attr.cast<AffineMapAttr>().getValue();
    auto &dimTable = iterationIndexMap[mapIndex++];
    for (unsigned pos = 0, numResults = affineMap.getNumResults();
         pos < numResults; ++pos) {
      auto dimExpr = affineMap.getResult(pos).cast<AffineDimExpr>();
      dimTable[dimExpr.getPosition()] = pos;
    }
  }
}
std::vector<std::pair<int64_t, int64_t>> ContractionOp::getContractingDimMap() {
SmallVector<AffineMap, 4> indexingMaps(getIndexingMaps());
return getDimMap(indexingMaps, iterator_types(),

View File

@ -77,6 +77,15 @@ static int64_t computeMaxLinearIndex(ArrayRef<int64_t> basis) {
return res;
}
/// Computes and returns the linearized index of 'offsets' w.r.t. 'basis',
/// i.e. the dot product of the two sequences.
static int64_t linearize(ArrayRef<int64_t> offsets, ArrayRef<int64_t> basis) {
  assert(offsets.size() == basis.size());
  int64_t acc = 0;
  unsigned dim = offsets.size();
  while (dim-- > 0)
    acc += offsets[dim] * basis[dim];
  return acc;
}
/// Given a shape with sizes greater than 0 along all dimensions, returns the
/// delinearized components of linearIndex along shape.
static SmallVector<int64_t, 8> delinearize(int64_t linearIndex,
@ -151,9 +160,9 @@ static Operation *cloneOpWithOperandsAndTypes(PatternRewriter &builder,
Location loc, Operation *op,
ArrayRef<Value *> operands,
ArrayRef<Type> resultTypes) {
OperationState *res = new OperationState(loc, op->getName().getStringRef(),
operands, resultTypes, {});
return builder.createOperation(*res);
OperationState res(loc, op->getName().getStringRef(), operands, resultTypes,
op->getAttrs());
return builder.createOperation(res);
}
// Helper function for Tablegen.
@ -164,6 +173,223 @@ static bool hasShape(Value *v, ArrayRef<int64_t> shape) {
return std::equal(t.getShape().begin(), t.getShape().end(), shape.begin());
}
// Returns a splat of zero with vector type 'vt'. Supports bf16/f16/f32/f64
// element types; aborts on any other element type.
// Fix: the f16/bf16 branch used getF64FloatAttr, creating a 64-bit float
// attribute for a 16-bit element type; use getF16FloatAttr so the attribute
// semantics match the element type (as the original pre-move version of this
// helper did).
static Value *makeSplatZero(Location loc, PatternRewriter &rewriter,
                            VectorType vt) {
  auto t = vt.getElementType();
  Value *f = nullptr;
  if (t.isBF16() || t.isF16())
    f = rewriter.create<ConstantOp>(loc, t, rewriter.getF16FloatAttr(0.0f));
  else if (t.isF32())
    f = rewriter.create<ConstantOp>(loc, t, rewriter.getF32FloatAttr(0.0f));
  else if (t.isF64())
    f = rewriter.create<ConstantOp>(loc, t, rewriter.getF64FloatAttr(0.0f));
  if (f)
    return rewriter.create<SplatOp>(loc, vt, f);
  llvm_unreachable("Unsupported type in `makeSplatZero`");
}
// Copies 'inputElements[i]' into 'resultElements[indexMap[i]]' for every
// index 'i' of 'inputElements' that has an entry in 'indexMap'; indices
// without a mapping are skipped.
static void getMappedElements(const DenseMap<int64_t, int64_t> &indexMap,
                              ArrayRef<int64_t> inputElements,
                              SmallVectorImpl<int64_t> &resultElements) {
  assert(indexMap.size() == resultElements.size());
  assert(inputElements.size() >= resultElements.size());
  for (unsigned idx = 0, numIn = inputElements.size(); idx < numIn; ++idx) {
    auto entry = indexMap.find(idx);
    if (entry == indexMap.end())
      continue;
    resultElements[entry->second] = inputElements[idx];
  }
}
// UnrolledOperandState aggregates per-operand state required for op unrolling.
struct UnrolledOperandState {
  // The vector-typed operand being unrolled.
  Value *operand;
  // Shape of each unrolled slice of 'operand' (the target shape projected
  // through the operand's iteration index map).
  SmallVector<int64_t, 4> unrolledShape;
  // Number of unrolled slices along each dimension of 'operand'.
  SmallVector<int64_t, 4> unrollFactors;
  // Strides used to linearize/delinearize slice offsets w.r.t. unrollFactors.
  SmallVector<int64_t, 8> basis;
  // Total number of unrolled slices of 'operand'.
  int64_t numInstances;
};
// Fills 'state' with the unrolled slice shape, per-dimension unroll factors,
// linearization basis, and total unrolled-instance count for 'operand'.
static void getUnrolledOperandState(Value *operand,
                                    const DenseMap<int64_t, int64_t> &indexMap,
                                    ArrayRef<int64_t> targetShape,
                                    UnrolledOperandState &state) {
  auto vectorType = operand->getType().cast<VectorType>();
  state.operand = operand;
  // Project 'targetShape' through 'indexMap' to obtain the slice shape.
  state.unrolledShape.resize(vectorType.getRank());
  getMappedElements(indexMap, targetShape, state.unrolledShape);
  // The unroll factor along each dimension is operand extent / slice extent.
  auto factors = shapeRatio(vectorType.getShape(), state.unrolledShape);
  assert(factors.hasValue());
  state.unrollFactors = *factors;
  // Derive strides and instance count from the unroll factors.
  state.basis = computeStrides(state.unrollFactors);
  state.numInstances = computeMaxLinearIndex(state.unrollFactors);
}
// Returns the linear index, within the operand described by 'state', of the
// unrolled slice located at iteration-space offsets 'vectorOffsets'.
static int64_t
getUnrolledOperandLinearIndex(UnrolledOperandState &state,
                              ArrayRef<int64_t> vectorOffsets,
                              DenseMap<int64_t, int64_t> &indexMap) {
  // Project the iteration-space offsets into this operand's vector space,
  // then linearize against the operand's stride basis.
  SmallVector<int64_t, 4> mappedOffsets(state.unrolledShape.size());
  getMappedElements(indexMap, vectorOffsets, mappedOffsets);
  return linearize(mappedOffsets, state.basis);
}
// Returns the unrolled slice at 'vectorOffsets' of the operand described by
// 'state'. On first use the slice is materialized with a
// vector.strided_slice op and memoized in 'cache'.
static Value *getOrCreateUnrolledOperandSlice(
    Location loc, UnrolledOperandState &state, ArrayRef<int64_t> vectorOffsets,
    ArrayRef<int64_t> offsets, DenseMap<int64_t, int64_t> &indexMap,
    SmallVectorImpl<Value *> &cache, PatternRewriter &builder) {
  // Project element-space 'offsets' into this operand's vector space.
  SmallVector<int64_t, 4> sliceOffsets(state.unrolledShape.size());
  getMappedElements(indexMap, offsets, sliceOffsets);
  // Locate the memoized slice for these vector-space offsets.
  int64_t cacheIndex =
      getUnrolledOperandLinearIndex(state, vectorOffsets, indexMap);
  assert(cacheIndex < static_cast<int64_t>(cache.size()));
  if (cache[cacheIndex] == nullptr) {
    // TODO(b/144845578) Support non-1 strides.
    SmallVector<int64_t, 4> sliceStrides(state.unrolledShape.size(), 1);
    cache[cacheIndex] = builder.create<vector::StridedSliceOp>(
        loc, state.operand, sliceOffsets, state.unrolledShape, sliceStrides);
  }
  return cache[cacheIndex];
}
//
// unrollSingleResultStructuredOp
//
// Returns a value representing the result of structured operation 'op'
// with iteration bounds 'iterationBounds' unrolled to 'targetShape'.
// An iteration space index map argument 'iterationIndexMapList' must be
// specified, with a map for each structured op input and a single map for the
// single result. The last map in the list must be the single result map.
// Extra operands can be passed to unrolled instances of 'op' using the
// 'extraOperands' argument.
//
// Example:
//
// // Before unrolling
//
// operand0 operand1 operand2
// \ | /
// -------------------- opA --------------------
//
// // After unrolling by 2
//
// operand0 operand1 operand2
// / \ / \ / \
// slice00 slice01 slice10 slice11 slice20 slice21
// \ | | | / |
// -------------------- opA0 -------------------- |
// | | | |
// \ | | /
// -------------------- opA1 -------------------
// | |
// \ /
// insertslice
// |
// TODO(andydavis) Generalize this to support structured ops beyond
// vector ContractionOp, and merge it with 'unrollSingleResultOpMatchingType'
// Unrolls the structured operation 'op' over its iteration space
// ('iterationBounds') into instances of 'targetShape', then reassembles the
// per-instance results into a single result vector via a zero splat plus
// vector.insert_strided_slice ops. 'iterationIndexMapList' holds one
// iteration-space-to-vector-dim map per mapped operand, with the single
// result's map last. 'extraOperands' are appended verbatim to every unrolled
// instance of 'op'.
// Fixes: the per-operand inner loop reused the outer loop's variable name
// 'i' (shadowing); loop counters compared 'unsigned' against int64_t counts;
// and the `if (cond) assert(false && ...)` pattern is replaced with direct
// asserts of the conditions (identical debug and release behavior).
static Value *unrollSingleResultStructuredOp(
    Operation *op, ArrayRef<int64_t> iterationBounds,
    std::vector<DenseMap<int64_t, int64_t>> &iterationIndexMapList,
    ArrayRef<int64_t> targetShape, ArrayRef<Value *> extraOperands,
    PatternRewriter &builder) {
  auto shapedType = op->getResult(0)->getType().dyn_cast_or_null<ShapedType>();
  assert(shapedType && shapedType.hasStaticShape() &&
         "Expected a statically shaped result type");

  // Compute unroll factors for 'iterationBounds' based on 'targetShape'.
  auto maybeUnrollFactors = shapeRatio(iterationBounds, targetShape);
  assert(maybeUnrollFactors.hasValue() &&
         "Failed to compute unroll factors for target shape");
  auto unrollFactors = *maybeUnrollFactors;

  // Compute unrolled operand state for each mapped operand.
  unsigned numMaps = iterationIndexMapList.size();
  SmallVector<UnrolledOperandState, 3> unrolledOperandState(numMaps);
  assert(op->getNumOperands() >= numMaps);
  for (unsigned mapIndex = 0; mapIndex < numMaps; ++mapIndex) {
    getUnrolledOperandState(op->getOperand(mapIndex),
                            iterationIndexMapList[mapIndex], targetShape,
                            unrolledOperandState[mapIndex]);
  }
  // Compute number of total unrolled instances.
  auto numUnrolledInstances = computeMaxLinearIndex(unrollFactors);
  auto basis = computeStrides(unrollFactors);

  // The result's state is always last in 'unrolledOperandState'.
  auto &resultOperandState = unrolledOperandState[numMaps - 1];
  auto unrolledResultType = VectorType::get(resultOperandState.unrolledShape,
                                            shapedType.getElementType());

  // Initialize caches for intermediate vector results.
  std::vector<SmallVector<Value *, 4>> caches(numMaps);
  for (unsigned mapIndex = 0; mapIndex < numMaps; ++mapIndex)
    caches[mapIndex].resize(unrolledOperandState[mapIndex].numInstances);

  // Unroll 'numUnrolledInstances' of 'op', storing results in 'caches'.
  for (int64_t instance = 0; instance < numUnrolledInstances; ++instance) {
    // De-linearize w.r.t. 'basis'.
    auto vectorOffsets = delinearize(instance, basis);
    // Convert from unrolled vector-space offsets to element-space offsets.
    auto offsets = zipMap([](int64_t v1, int64_t v2) { return v1 * v2; },
                          vectorOffsets, targetShape);
    // Get cached slice (or create slice) for each operand at 'offsets'.
    SmallVector<Value *, 3> operands;
    operands.reserve(numMaps);
    for (unsigned mapIndex = 0; mapIndex < numMaps; ++mapIndex) {
      operands.push_back(getOrCreateUnrolledOperandSlice(
          op->getLoc(), unrolledOperandState[mapIndex], vectorOffsets, offsets,
          iterationIndexMapList[mapIndex], caches[mapIndex], builder));
    }
    // Create op on sliced vector arguments.
    operands.append(extraOperands.begin(), extraOperands.end());
    auto resultVector =
        cloneOpWithOperandsAndTypes(builder, op->getLoc(), op, operands,
                                    unrolledResultType)
            ->getResult(0);

    // Compute linear result index.
    int64_t resultIndex = getUnrolledOperandLinearIndex(
        resultOperandState, vectorOffsets, iterationIndexMapList[numMaps - 1]);
    // Update result cache at 'resultIndex'.
    caches[numMaps - 1][resultIndex] = resultVector;
  }

  // Make a zero splat into which results from 'caches[numMaps - 1]' are
  // inserted.
  auto resultVectorType = op->getResult(0)->getType().cast<VectorType>();
  auto *res = makeSplatZero(op->getLoc(), builder, resultVectorType);
  SmallVector<int64_t, 4> strides(resultOperandState.unrollFactors.size(), 1);
  // Insert vector accumulators into output.
  for (int64_t instance = 0; instance < resultOperandState.numInstances;
       ++instance) {
    auto vectorOffsets = delinearize(instance, resultOperandState.basis);
    // Convert from unrolled vector-space offsets to element-space offsets.
    auto offsets = zipMap([](int64_t v1, int64_t v2) { return v1 * v2; },
                          vectorOffsets, resultOperandState.unrolledShape);
    res = builder.create<vector::InsertStridedSliceOp>(
        op->getLoc(), caches[numMaps - 1][instance], res, offsets, strides);
  }
  return res;
}
// Entry point for unrolling declarative pattern rewrites.
// `op` is unrolled to the `targetShape` as follows, for each of its operands:
// 1. the unrolled type `unrolledVectorType` and number of unrolled instances
@ -200,6 +426,26 @@ static bool hasShape(Value *v, ArrayRef<int64_t> shape) {
Value * mlir::vector::unrollSingleResultOpMatchingType(PatternRewriter &builder,
Operation *op,
ArrayRef<int64_t> targetShape) {
if (auto contractionOp = dyn_cast<vector::ContractionOp>(op)) {
// Get contraction op iteration bounds.
SmallVector<int64_t, 6> iterationBounds;
contractionOp.getIterationBounds(iterationBounds);
assert(iterationBounds.size() == targetShape.size());
// Get map from iteration space index to lhs/rhs/result shape index.
std::vector<DenseMap<int64_t, int64_t>> iterationIndexMapList;
contractionOp.getIterationIndexMap(iterationIndexMapList);
// TODO(andydavis) Support unrollable vector masks.
SmallVector<Value *, 2> masks(contractionOp.masks().begin(),
contractionOp.masks().end());
// Unroll 'op' 'iterationBounds' to 'targetShape'.
return unrollSingleResultStructuredOp(op, iterationBounds,
iterationIndexMapList, targetShape,
masks, builder);
}
// TODO(andydavis) Create trivial iteration bounds and index map for
// elementwise operations and call 'unrollSingleResultStructuredOp'. Remove
// fakefork/join if possible.
LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE
"]: unrollSingleResultOpMatchingType on func:\n");
LLVM_DEBUG(op->getParentOfType<FuncOp>().print(dbgs()));
@ -365,24 +611,6 @@ struct ConvertFakeForkFromBlockArgsOp : public RewritePattern {
}
};
// Returns a splat of zero with vector type 'vt' (bf16/f16/f32/f64 element
// types only; aborts on anything else).
// NOTE(review): this is the pre-move copy of makeSplatZero that this commit
// deletes; the helper now lives earlier in the file.
static Value *makeSplatZero(Location loc, PatternRewriter &rewriter,
VectorType vt) {
auto t = vt.getElementType();
Value *f = nullptr;
// 16-bit float types use an f16 attribute so attribute and element type agree.
if (t.isBF16() || t.isF16())
f = rewriter.create<ConstantOp>(loc, t, rewriter.getF16FloatAttr(0.0f))
.getResult();
else if (t.isF32())
f = rewriter.create<ConstantOp>(loc, t, rewriter.getF32FloatAttr(0.0f))
.getResult();
else if (t.isF64())
f = rewriter.create<ConstantOp>(loc, t, rewriter.getF64FloatAttr(0.0f))
.getResult();
if (f)
return rewriter.create<SplatOp>(loc, vt, f).getResult();
llvm_unreachable("Unsupported type in `makeSplatZero`");
}
// Rewrites a fakeJoin, whose (unique) operand is a blockArgument, into multiple
// vector.strided_slice ops.
struct ConvertFakeJoinOp : public RewritePattern {