Unify vector op unrolling transformation.

Unifies the vector op unrolling transformation by using the same unrolling implementation for contraction and elementwise operations. Removes the fake fork/join operations, which are no longer needed now that we have the InsertStridedSlice operation.

PiperOrigin-RevId: 284570784
Change-Id: Ic3412bc6456d91bc24afda25bbe98888ff4d5849
This commit is contained in:
parent
27a98ccab4
commit
900bf75eb5
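For context, the unified entry point introduced below is meant to be driven from a rewrite pattern. A minimal sketch, assuming an MLIR checkout contemporary with this commit (Value pointers, PatternMatchResult); the pattern name, the header comment, and the {2, 2} target shape are illustrative, not part of this change:

    #include "mlir/IR/PatternMatch.h"
    // The declaration of unrollSingleResultOpMatchingType lives wherever this
    // tree exposes the vector-to-vector utilities; the exact header path is an
    // assumption here.

    using namespace mlir;

    // Illustrative pattern: unroll a single-result vector op to 2x2 tiles.
    struct UnrollToTargetShape : public RewritePattern {
      UnrollToTargetShape(StringRef opName, MLIRContext *ctx)
          : RewritePattern(opName, /*benefit=*/1, ctx) {}

      PatternMatchResult matchAndRewrite(Operation *op,
                                         PatternRewriter &rewriter) const override {
        // Precondition (see the patch): the target shape must divide the op's
        // static vector shape.
        Value *unrolled = vector::unrollSingleResultOpMatchingType(
            rewriter, op, /*targetShape=*/{2, 2});
        rewriter.replaceOp(op, unrolled);
        return matchSuccess();
      }
    };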
@@ -165,6 +165,7 @@ def Vector_ContractionOp :
     static StringRef getParallelIteratorTypeName() {
       return "parallel";
     }
+    static unsigned getAccOperandIndex() { return 2; }

     // Returns the bounds of each dimension in the iteration space spanned
     // by the iterator types of this operation.
@@ -102,58 +102,6 @@ static SmallVector<int64_t, 8> delinearize(int64_t linearIndex,
   return res;
 }

-static constexpr auto kFakeForkOp = "__fake_fork__";
-static constexpr auto kFakeJoinOp = "__fake_join__";
-static constexpr auto kUnrollAttrName = "__unroll__";
-static constexpr auto kBaseCoordAttrName = "__base_coord__";
-
-// Reads the IntegerArray attribute named `kUnrollAttrName` from `op` and
-// returns its representation as a vector of integers.
-static SmallVector<int64_t, 8> extractUnrollFactors(Operation *op) {
-  SmallVector<int64_t, 8> res;
-  auto unrollAttr = op->getAttr(kUnrollAttrName);
-  if (!unrollAttr)
-    return res;
-  auto unrollArrayAttr = unrollAttr.cast<ArrayAttr>();
-  res.reserve(unrollArrayAttr.size());
-  for (auto attr : unrollArrayAttr) {
-    auto unroll = attr.cast<IntegerAttr>().getValue().getSExtValue();
-    assert(unroll > 0);
-    res.push_back(unroll);
-  }
-  return res;
-}
-
-// Creates a custom `kFakeForkOp` used in progressive lowering to other vector
-// operations.
-static Operation *createFakeForkOp(PatternRewriter &builder, Location loc,
-                                   Value *operand, ArrayRef<Type> resultTypes,
-                                   ArrayRef<int64_t> unrollFactors = {}) {
-  OperationState *forkOp =
-      new OperationState(loc, kFakeForkOp, operand, resultTypes, {});
-  if (!unrollFactors.empty())
-    forkOp->addAttribute(kUnrollAttrName,
-                         builder.getI64ArrayAttr(unrollFactors));
-  return builder.createOperation(*forkOp);
-}
-
-// Creates a custom `kFakeJoinOp` used in progressive lowering to other vector
-// operations.
-static Operation *createFakeJoinOp(PatternRewriter &builder, Location loc,
-                                   ArrayRef<Value *> operands, Type resultType,
-                                   ArrayRef<int64_t> unrollFactors = {},
-                                   ArrayRef<int64_t> baseCoords = {}) {
-  OperationState *joinOp =
-      new OperationState(loc, kFakeJoinOp, operands, resultType, {});
-  if (!unrollFactors.empty())
-    joinOp->addAttribute(kUnrollAttrName,
-                         builder.getI64ArrayAttr(unrollFactors));
-  if (!baseCoords.empty())
-    joinOp->addAttribute(kBaseCoordAttrName,
-                         builder.getI64ArrayAttr(baseCoords));
-  return builder.createOperation(*joinOp);
-}
-
 // Clones `op` into a new operation that takes `operands` and returns
 // `resultTypes`.
 static Operation *cloneOpWithOperandsAndTypes(PatternRewriter &builder,
@@ -202,9 +150,9 @@ static void getMappedElements(const DenseMap<int64_t, int64_t> &indexMap,
   }
 }

-// UnrolledOperandState aggregates per-operand state required for op unrolling.
-struct UnrolledOperandState {
-  Value *operand;
+// UnrolledVectorState aggregates per-operand/result vector state required for
+// unrolling.
+struct UnrolledVectorState {
   SmallVector<int64_t, 4> unrolledShape;
   SmallVector<int64_t, 4> unrollFactors;
   SmallVector<int64_t, 8> basis;
@@ -212,14 +160,12 @@ struct UnrolledOperandState {
 };

 // Populates 'state' with unrolled shape, unroll factors, basis and
-// num unrolled instances for 'operand'.
-static void getUnrolledOperandState(Value *operand,
+// num unrolled instances for 'vectorType'.
+static void initUnrolledVectorState(VectorType vectorType,
                                     const DenseMap<int64_t, int64_t> &indexMap,
                                     ArrayRef<int64_t> targetShape,
-                                    UnrolledOperandState &state) {
-  auto vectorType = operand->getType().cast<VectorType>();
-  state.operand = operand;
-  // Compute unrolled shape of 'operand'.
+                                    UnrolledVectorState &state) {
+  // Compute unrolled shape of 'vectorType'.
   state.unrolledShape.resize(vectorType.getRank());
   getMappedElements(indexMap, targetShape, state.unrolledShape);
   // Compute unroll factors for unrolled shape.
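To make the per-vector state concrete: for a vector of shape 4x8 unrolled to target shape {2, 4} under an identity index map, the state comes out as unrolledShape = 2x4, unrollFactors = 2x2 (a 2x2 grid of tiles), and numInstances = 4. A standalone C++ sketch of that arithmetic, reimplementing the patch's helpers (shapeRatio, computeStrides, computeMaxLinearIndex) with the STL for illustration only:

    #include <cassert>
    #include <cstdint>
    #include <iostream>
    #include <map>
    #include <vector>

    int main() {
      std::vector<int64_t> shape = {4, 8};        // vector<4x8xf32>
      std::vector<int64_t> targetShape = {2, 4};  // unroll target
      std::map<int64_t, int64_t> indexMap = {{0, 0}, {1, 1}};  // identity

      // unrolledShape: target sizes gathered through the index map
      // (what getMappedElements does in the patch).
      std::vector<int64_t> unrolledShape(shape.size());
      for (auto &kv : indexMap)
        unrolledShape[kv.second] = targetShape[kv.first];

      // unrollFactors: tiles per dimension (shapeRatio in the patch).
      std::vector<int64_t> unrollFactors(shape.size());
      for (size_t i = 0; i < shape.size(); ++i) {
        assert(shape[i] % unrolledShape[i] == 0 && "target must divide shape");
        unrollFactors[i] = shape[i] / unrolledShape[i];
      }

      // basis: row-major strides over the tile grid (computeStrides).
      std::vector<int64_t> basis(unrollFactors.size(), 1);
      for (int i = (int)unrollFactors.size() - 2; i >= 0; --i)
        basis[i] = basis[i + 1] * unrollFactors[i + 1];

      // numInstances: total number of tiles (computeMaxLinearIndex).
      int64_t numInstances = 1;
      for (int64_t f : unrollFactors)
        numInstances *= f;

      std::cout << unrolledShape[0] << "x" << unrolledShape[1] << "\n";  // 2x4
      std::cout << unrollFactors[0] << "x" << unrollFactors[1] << "\n";  // 2x2
      std::cout << numInstances << "\n";                                 // 4
    }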
@@ -233,53 +179,72 @@ static void getUnrolledOperandState(Value *operand,
 }

 // Computes and returns the linear index of the unrolled vector at
-// 'vectorOffsets' within the vector operand represented by 'state'.
+// 'vectorOffsets' within the vector represented by 'state'.
 static int64_t
-getUnrolledOperandLinearIndex(UnrolledOperandState &state,
-                              ArrayRef<int64_t> vectorOffsets,
-                              DenseMap<int64_t, int64_t> &indexMap) {
-  // Compute operand offsets.
+getUnrolledVectorLinearIndex(UnrolledVectorState &state,
+                             ArrayRef<int64_t> vectorOffsets,
+                             DenseMap<int64_t, int64_t> &indexMap) {
+  // Compute vector offsets.
   SmallVector<int64_t, 4> sliceOffsets(state.unrolledShape.size());
   getMappedElements(indexMap, vectorOffsets, sliceOffsets);
   // Compute and return linear index of 'sliceOffsets' w.r.t 'state.basis'.
   return linearize(sliceOffsets, state.basis);
 }

-// Returns an unrolled vector at 'vectorOffsets' within the vector operand
-// represented by 'state'. The value is created if not present in 'cache'.
-static Value *getOrCreateUnrolledOperandSlice(
-    Location loc, UnrolledOperandState &state, ArrayRef<int64_t> vectorOffsets,
+// Returns an unrolled vector at 'vectorOffsets' within the vector
+// represented by 'state'. The vector is created from a slice of 'initValue'
+// if not present in 'cache'.
+static Value *getOrCreateUnrolledVectorSlice(
+    Location loc, UnrolledVectorState &state, ArrayRef<int64_t> vectorOffsets,
     ArrayRef<int64_t> offsets, DenseMap<int64_t, int64_t> &indexMap,
-    SmallVectorImpl<Value *> &cache, PatternRewriter &builder) {
-  // Compute operand offsets.
+    Value *initValue, SmallVectorImpl<Value *> &cache,
+    PatternRewriter &builder) {
+  // Compute slice offsets.
   SmallVector<int64_t, 4> sliceOffsets(state.unrolledShape.size());
   getMappedElements(indexMap, offsets, sliceOffsets);
   // TODO(b/144845578) Support non-1 strides.
   SmallVector<int64_t, 4> sliceStrides(state.unrolledShape.size(), 1);
   // Compute linear index of 'sliceOffsets' w.r.t 'state.basis'.
   int64_t sliceLinearIndex =
-      getUnrolledOperandLinearIndex(state, vectorOffsets, indexMap);
+      getUnrolledVectorLinearIndex(state, vectorOffsets, indexMap);
   assert(sliceLinearIndex < static_cast<int64_t>(cache.size()));
-  auto *operandSlice = cache[sliceLinearIndex];
-  if (operandSlice == nullptr) {
-    // Initialize 'cache' with slice from 'state.operand'.
-    operandSlice = builder.create<vector::StridedSliceOp>(
-        loc, state.operand, sliceOffsets, state.unrolledShape, sliceStrides);
+  auto *valueSlice = cache[sliceLinearIndex];
+  if (valueSlice == nullptr) {
+    assert(initValue != nullptr);
+    // Initialize 'cache' with slice from 'initValue'.
+    valueSlice = builder.create<vector::StridedSliceOp>(
+        loc, initValue, sliceOffsets, state.unrolledShape, sliceStrides);
     // Store value back to 'cache'.
-    cache[sliceLinearIndex] = operandSlice;
+    cache[sliceLinearIndex] = valueSlice;
   }
-  return operandSlice;
+  return valueSlice;
 }

+// VectorState aggregates per-operand/result vector state required for
+// creating slices of vector operands, and clones of the operation being
+// unrolled.
+struct VectorState {
+  // The type of this vector.
+  VectorType type;
+  // Map from iteration space index to vector dimension index.
+  DenseMap<int64_t, int64_t> indexMap;
+  // Index of this value in the operation's operand list (-1 if not an operand).
+  int64_t operandIndex = -1;
+  // Accumulator iterator flag.
+  bool isAcc = false;
+};
+
 //
 // unrollSingleResultStructuredOp
 //
 // Returns a value representing the result of structured operation 'op'
 // with iteration bounds 'iterationBounds' unrolled to 'targetShape'.
-// An iteration space index map argument 'iterationIndexMapList' must be
-// specified, with a map for each structured op input and a single map for the
-// single result. The map at index 'indexMapListResultIndex' in the list must
-// be the single result map.
+// A list of VectorState objects must be specified in 'vectors', where
+// each VectorState in the list represents a vector operand or vector result
+// (if the operation does not have an accumulator operand).
+// The VectorState at index 'resultIndex' in the list must be the state
+// associated with the operation's single result (i.e. either its accumulator
+// operand or vector result value).
 //
 // Example:
 //
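The cache logic above is plain memoization keyed by the linear tile index: a slice of an input is materialized at most once and reused across unrolled instances. A minimal standalone analogue, where the makeSlice callback stands in for emitting a vector.strided_slice; all names here are illustrative:

    #include <cstdint>
    #include <functional>
    #include <vector>

    // Get-or-create discipline of getOrCreateUnrolledVectorSlice: cache[index]
    // is created on first use and returned as-is afterwards. T is assumed to
    // be pointer-like (null means "not yet created").
    template <typename T>
    T getOrCreate(std::vector<T> &cache, int64_t index,
                  const std::function<T()> &makeSlice) {
      if (!cache[index])
        cache[index] = makeSlice(); // e.g. build a strided slice op here
      return cache[index];
    }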
@@ -304,13 +269,24 @@ static Value *getOrCreateUnrolledOperandSlice(
 //             insertslice
 //                  |

+// TODO(andydavis) Add the following canonicalization/simplification patterns:
+// *) Add pattern which matches InsertStridedSlice -> StridedSlice and forwards
+//    the InsertStridedSlice operand to the StridedSlice.
+// *) Add pattern which matches SourceOp -> StridedSlice -> UserOp which checks
+//    if there are duplicate identical StridedSlice ops from SourceOp, and
+//    rewrites itself to use the first duplicate. This transformation should
+//    cause users of identical StridedSlice ops to reuse the same StridedSlice
+//    operation, and leave the duplicate StridedSlice ops with no users
+//    (removable with DCE).
+
 // TODO(andydavis) Generalize this to support structured ops beyond
 // vector ContractionOp, and merge it with 'unrollSingleResultOpMatchingType'
-static Value *unrollSingleResultStructuredOp(
-    Operation *op, ArrayRef<int64_t> iterationBounds,
-    std::vector<DenseMap<int64_t, int64_t>> &iterationIndexMapList,
-    unsigned indexMapListResultIndex, ArrayRef<int64_t> targetShape,
-    PatternRewriter &builder) {
+static Value *unrollSingleResultStructuredOp(Operation *op,
+                                             ArrayRef<int64_t> iterationBounds,
+                                             std::vector<VectorState> &vectors,
+                                             unsigned resultIndex,
+                                             ArrayRef<int64_t> targetShape,
+                                             PatternRewriter &builder) {
   auto shapedType = op->getResult(0)->getType().dyn_cast_or_null<ShapedType>();
   if (!shapedType || !shapedType.hasStaticShape())
     assert(false && "Expected a statically shaped result type");
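Neither TODO pattern is implemented by this patch. A rough sketch of the first one, for orientation only; the sameRange helper and the exact way a real version would compare offsets, sizes, and strides through the ops' attributes are assumptions:

    // Sketch: if a vector.strided_slice reads back exactly the range that a
    // vector.insert_strided_slice just wrote, forward the inserted vector.
    // 'sameRange' is a placeholder, conservatively returning false here.
    static bool sameRange(Operation *slice, Operation *insert) { return false; }

    struct ForwardInsertStridedSlice : public RewritePattern {
      ForwardInsertStridedSlice(MLIRContext *ctx)
          : RewritePattern(vector::StridedSliceOp::getOperationName(),
                           /*benefit=*/1, ctx) {}

      PatternMatchResult matchAndRewrite(Operation *op,
                                         PatternRewriter &rewriter) const override {
        auto *def = op->getOperand(0)->getDefiningOp();
        if (!def || !isa<vector::InsertStridedSliceOp>(def))
          return matchFailure();
        if (!sameRange(op, def))
          return matchFailure();
        // Operand 0 of insert_strided_slice is the vector that was inserted.
        rewriter.replaceOp(op, def->getOperand(0));
        return matchSuccess();
      }
    };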
@@ -321,27 +297,25 @@ static Value *unrollSingleResultStructuredOp(
     assert(false && "Failed to compute unroll factors for target shape");
   auto unrollFactors = *maybeUnrollFactors;

-  // Compute unrolled operation state for each mapped operand.
-  unsigned numMaps = iterationIndexMapList.size();
-  SmallVector<UnrolledOperandState, 3> unrolledOperandState(numMaps);
-  assert(op->getNumOperands() >= numMaps);
-  for (unsigned i = 0; i < numMaps; ++i) {
-    getUnrolledOperandState(op->getOperand(i), iterationIndexMapList[i],
-                            targetShape, unrolledOperandState[i]);
+  // Compute unrolled vector state for each vector in 'vectors'.
+  unsigned numVectors = vectors.size();
+  SmallVector<UnrolledVectorState, 3> unrolledVectorState(numVectors);
+  for (unsigned i = 0; i < numVectors; ++i) {
+    initUnrolledVectorState(vectors[i].type, vectors[i].indexMap, targetShape,
+                            unrolledVectorState[i]);
   }
   // Compute number of total unrolled instances.
   auto numUnrolledInstances = computeMaxLinearIndex(unrollFactors);
   auto basis = computeStrides(unrollFactors);

-  auto &resultOperandState = unrolledOperandState[indexMapListResultIndex];
-  auto unrolledResultType = VectorType::get(resultOperandState.unrolledShape,
+  auto &resultValueState = unrolledVectorState[resultIndex];
+  auto unrolledResultType = VectorType::get(resultValueState.unrolledShape,
                                             shapedType.getElementType());

   // Initialize caches for intermediate vector results.
-  std::vector<SmallVector<Value *, 4>> caches(numMaps);
-  for (unsigned i = 0; i < numMaps; ++i) {
-    caches[i].resize(unrolledOperandState[i].numInstances);
-  }
+  std::vector<SmallVector<Value *, 4>> caches(numVectors);
+  for (unsigned i = 0; i < numVectors; ++i)
+    caches[i].resize(unrolledVectorState[i].numInstances);

   // Unroll 'numUnrolledInstances' of 'op', storing results in 'caches'.
   for (unsigned i = 0; i < numUnrolledInstances; ++i) {
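The unroll loop that follows walks tiles by linear index i, and delinearize/linearize convert between that index and per-dimension tile offsets with respect to basis. A standalone round trip over the 2x2 tile grid from the earlier example, with the helpers reimplemented here for illustration:

    #include <cassert>
    #include <cstdint>
    #include <iostream>
    #include <vector>

    // Mirrors the patch's helpers: basis holds row-major strides over the
    // tile grid.
    static std::vector<int64_t> delinearize(int64_t linearIndex,
                                            const std::vector<int64_t> &basis) {
      std::vector<int64_t> res;
      for (int64_t b : basis) {
        res.push_back(linearIndex / b);
        linearIndex %= b;
      }
      return res;
    }

    static int64_t linearize(const std::vector<int64_t> &offsets,
                             const std::vector<int64_t> &basis) {
      int64_t linear = 0;
      for (size_t i = 0; i < basis.size(); ++i)
        linear += offsets[i] * basis[i];
      return linear;
    }

    int main() {
      std::vector<int64_t> basis = {2, 1};  // 2x2 tile grid
      for (int64_t i = 0; i < 4; ++i) {
        auto offsets = delinearize(i, basis);    // {0,0} {0,1} {1,0} {1,1}
        assert(linearize(offsets, basis) == i);  // round-trips exactly
        std::cout << i << " -> (" << offsets[0] << ", " << offsets[1] << ")\n";
      }
    }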
@@ -352,11 +326,15 @@ static Value *unrollSingleResultStructuredOp(
                           vectorOffsets, targetShape);
     // Get cached slice (or create slice) for each operand at 'offsets'.
     SmallVector<Value *, 3> operands;
-    operands.reserve(numMaps);
-    for (unsigned i = 0; i < numMaps; ++i) {
-      operands.push_back(getOrCreateUnrolledOperandSlice(
-          op->getLoc(), unrolledOperandState[i], vectorOffsets, offsets,
-          iterationIndexMapList[i], caches[i], builder));
+    operands.resize(op->getNumOperands());
+    for (unsigned i = 0; i < numVectors; ++i) {
+      int64_t operandIndex = vectors[i].operandIndex;
+      if (operandIndex < 0)
+        continue; // Output
+      auto *operand = op->getOperand(operandIndex);
+      operands[operandIndex] = getOrCreateUnrolledVectorSlice(
+          op->getLoc(), unrolledVectorState[i], vectorOffsets, offsets,
+          vectors[i].indexMap, operand, caches[i], builder);
     }
     // Create op on sliced vector arguments.
     auto resultVector =
@@ -365,146 +343,117 @@ static Value *unrollSingleResultStructuredOp(
             ->getResult(0);

     // Compute linear result index.
-    int64_t resultIndex = getUnrolledOperandLinearIndex(
-        resultOperandState, vectorOffsets,
-        iterationIndexMapList[indexMapListResultIndex]);
-    // Update result cache at 'resultIndex'.
-    caches[indexMapListResultIndex][resultIndex] = resultVector;
+    int64_t linearIndex = getUnrolledVectorLinearIndex(
+        resultValueState, vectorOffsets, vectors[resultIndex].indexMap);
+    // Update result cache at 'linearIndex'.
+    caches[resultIndex][linearIndex] = resultVector;
   }

   // Make zero splat into which we will insert results from
-  // 'cache[indexMapListResultIndex]'
+  // 'cache[resultIndex]'
   auto resultVectorType = op->getResult(0)->getType().cast<VectorType>();
   auto *res = makeSplatZero(op->getLoc(), builder, resultVectorType);
-  SmallVector<int64_t, 4> strides(resultOperandState.unrollFactors.size(), 1);
+  SmallVector<int64_t, 4> strides(resultValueState.unrollFactors.size(), 1);
   // Insert vector accumulators into output.
-  for (unsigned i = 0; i < resultOperandState.numInstances; ++i) {
-    auto vectorOffsets = delinearize(i, resultOperandState.basis);
+  for (unsigned i = 0; i < resultValueState.numInstances; ++i) {
+    auto vectorOffsets = delinearize(i, resultValueState.basis);
     // Convert from unrolled vector-space offsets to element-space offsets.
     auto offsets = zipMap([](int64_t v1, int64_t v2) { return v1 * v2; },
-                          vectorOffsets, resultOperandState.unrolledShape);
+                          vectorOffsets, resultValueState.unrolledShape);
     res = builder.create<vector::InsertStridedSliceOp>(
-        op->getLoc(), caches[indexMapListResultIndex][i], res, offsets,
-        strides);
+        op->getLoc(), caches[resultIndex][i], res, offsets, strides);
   }

   return res;
 }

-// Entry point for unrolling declarative pattern rewrites.
-// `op` is unrolled to the `targetShape` as follows, for each of its operands:
-// 1. the unrolled type `unrolledVectorType` and number of unrolled instances
-//    `numUnrolledInstances` are computed from the `targetShape`. For now it is
-//    assumed the unrolling factors divide the vector sizes.
-// 2. a fakeFork cast op is inserted that takes the operand and returns
-//    `numUnrolledInstances` results of type `unrolledVectorType`.
-// 3. the original op is cloned `numUnrolledInstances` times, once for each
-//    result of the fakeFork cast op.
-// 4. a fakeJoin cast op takes all these results and merges them into a single
-//    aggregate vector result whose size matches the original non-unrolled op
-//    operand types.
-//
-// Example:
-//
-//   opA(operand0, operand1)  // numUnrolledInstances = 3
-//
-//           operand0                   operand1
-//              |                          |
-//             fork                       fork
-//        <----------gather all fork ops --------->
-//              /|\                        /|\
-//          f00 f01 f02                f10 f11 f12
-//        <---------- clone op 3 times --------->
-//          opA0(f00, f10), opA1(f01, f11), opA2(f02, f12)
-//                 \            |            /
-//        <-------------------- join ------------------------->
-//
-// Other local patterns then kick in iteratively (including DCE) and compose
-// until all the fakeFork and fakeJoin ops are removed.
-//
-// This will be extended in the future to support more advanced use cases than
-// simple pointwise ops.
-Value * mlir::vector::unrollSingleResultOpMatchingType(PatternRewriter &builder,
-                                                       Operation *op,
-                                                       ArrayRef<int64_t> targetShape) {
-  if (auto contractionOp = dyn_cast<vector::ContractionOp>(op)) {
-    // Get contraction op iteration bounds.
-    SmallVector<int64_t, 6> iterationBounds;
-    contractionOp.getIterationBounds(iterationBounds);
-    assert(iterationBounds.size() == targetShape.size());
-    // Get map from iteration space index to lhs/rhs/result shape index.
-    std::vector<DenseMap<int64_t, int64_t>> iterationIndexMapList;
-    contractionOp.getIterationIndexMap(iterationIndexMapList);
-    if (llvm::size(contractionOp.masks()) == 2) {
-      // Add maps for lhs/rhs vector mask arguments (same lhs/rhs vector shape)
-      iterationIndexMapList.push_back(iterationIndexMapList[0]);
-      iterationIndexMapList.push_back(iterationIndexMapList[1]);
-    }
-    // Unroll 'op' 'iterationBounds' to 'targetShape'.
-    // TODO(andydavis) Use linalg style 'args_in'/'args_out' to partition
-    // 'iterationIndexMapList' instead of 'indexMapListResultIndex'.
-    return unrollSingleResultStructuredOp(
-        op, iterationBounds, iterationIndexMapList,
-        /*indexMapListResultIndex=*/2, targetShape, builder);
-  }
-  // TODO(andydavis) Create trivial iteration bounds and index map for
-  // elementwise operations and call 'unrollSingleResultStructuredOp'. Remove
-  // fakefork/join if possible.
-
-  LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE
-                       "]: unrollSingleResultOpMatchingType on func:\n");
-  LLVM_DEBUG(op->getParentOfType<FuncOp>().print(dbgs()));
-  if (!op->getNumResults())
-    assert(false && "Use precondition till RewriterGen can act on nullptr");
-
-  auto shapedType = op->getResult(0)->getType().dyn_cast_or_null<ShapedType>();
-  if (!shapedType || !shapedType.hasStaticShape())
-    assert(false && "Use precondition till RewriterGen can act on nullptr");
-
-  auto shape = shapedType.getShape();
-  auto maybeUnrollFactors = shapeRatio(shape, targetShape);
-  if (!maybeUnrollFactors.hasValue())
-    assert(false && "Use precondition till RewriterGen can act on nullptr");
-  auto unrollFactors = *maybeUnrollFactors;
-
-  auto loc = op->getLoc();
-  auto numUnrolledInstances = computeMaxLinearIndex(unrollFactors);
-  auto unrolledVectorType =
-      VectorType::get(targetShape, shapedType.getElementType());
-  SmallVector<Type, 4> forkedType(numUnrolledInstances, unrolledVectorType);
-  SmallVector<Operation *, 4> forkeds;
-  forkeds.reserve(numUnrolledInstances);
-  // Create a new forkOp for each operand.
-  for (auto *operand : op->getOperands())
-    forkeds.push_back(
-        createFakeForkOp(builder, loc, operand, forkedType, unrollFactors));
-
-  SmallVector<Operation *, 4> newOps;
-  newOps.reserve(numUnrolledInstances);
-  for (int64_t idx = 0; idx < numUnrolledInstances; ++idx) {
-    SmallVector<Value *, 4> operands;
-    operands.reserve(forkeds.size());
-    for (auto *fork : forkeds) {
-      operands.push_back(fork->getResult(idx));
-    }
-    newOps.push_back(cloneOpWithOperandsAndTypes(builder, loc, op, operands,
-                                                 unrolledVectorType));
+static void getVectorContractionOpUnrollState(
+    vector::ContractionOp contractionOp, ArrayRef<int64_t> targetShape,
+    SmallVectorImpl<int64_t> &iterationBounds,
+    std::vector<VectorState> &vectors, unsigned &resultIndex) {
+  // Get contraction op iteration bounds.
+  contractionOp.getIterationBounds(iterationBounds);
+  assert(iterationBounds.size() == targetShape.size());
+  // Get map from iteration space index to lhs/rhs/result shape index.
+  std::vector<DenseMap<int64_t, int64_t>> iterationIndexMapList;
+  contractionOp.getIterationIndexMap(iterationIndexMapList);
+  unsigned numIterators = iterationIndexMapList.size();
+  vectors.resize(numIterators);
+  unsigned accOperandIndex = vector::ContractionOp::getAccOperandIndex();
+  for (unsigned i = 0; i < numIterators; ++i) {
+    vectors[i].type = contractionOp.getOperand(i)->getType().cast<VectorType>();
+    vectors[i].indexMap = iterationIndexMapList[i];
+    vectors[i].operandIndex = i;
+    vectors[i].isAcc = i == accOperandIndex ? true : false;
   }
-
-  SmallVector<Value *, 4> newOpResults;
-  newOpResults.reserve(newOps.size());
-  for (auto *newOp : newOps)
-    newOpResults.push_back(newOp->getResult(0));
-
-  return createFakeJoinOp(builder, loc, newOpResults, shapedType, unrollFactors,
-                          {0})
-      ->getResult(0);
+  if (llvm::size(contractionOp.masks()) == 2) {
+    // Add vectors for lhs/rhs vector mask arguments. Masks have the
+    // same vector shape as the lhs/rhs args, so copy their index maps.
+    vectors.push_back(
+        {vectors[0].type, vectors[0].indexMap, accOperandIndex + 1, false});
+    vectors.push_back(
+        {vectors[1].type, vectors[1].indexMap, accOperandIndex + 2, false});
+  }
+  // Unroll 'op' 'iterationBounds' to 'targetShape'.
+  // TODO(andydavis) Use linalg style 'args_in'/'args_out' to partition
+  // 'vectors' instead of 'resultIndex'.
+  resultIndex = accOperandIndex;
 }

-// Patterns with this benefit just forward arguments to clean up fake fork and
-// fake join ops. It is a nicer and more direct cleanup when we can use it, so
-// it kicks in with higher precedence.
-static constexpr int64_t kMatchingFakeForkFakeJoinBenefit = 1;
+static void
+getVectorElementwiseOpUnrollState(Operation *op, ArrayRef<int64_t> targetShape,
+                                  SmallVectorImpl<int64_t> &iterationBounds,
+                                  std::vector<VectorState> &vectors,
+                                  unsigned &resultIndex) {
+  // Verify that operation and operands all have the same vector shape.
+  auto resultType = op->getResult(0)->getType().dyn_cast_or_null<VectorType>();
+  assert(resultType && "Expected op with vector result type");
+  auto resultShape = resultType.getShape();
+  // Verify that all operands have the same vector type as result.
+  assert(llvm::all_of(op->getOperandTypes(),
+                      [=](Type type) { return type == resultType; }));
+  // Populate 'iterationBounds' with 'resultShape' for elementwise operations.
+  iterationBounds.assign(resultShape.begin(), resultShape.end());
+
+  // Create trivial elementwise identity index map based on 'resultShape'.
+  DenseMap<int64_t, int64_t> indexMap;
+  indexMap.reserve(resultShape.size());
+  for (unsigned i = 0; i < resultShape.size(); ++i)
+    indexMap[i] = i;
+
+  // Create a VectorState for each operand and for the single result.
+  unsigned numVectors = op->getNumOperands() + op->getNumResults();
+  vectors.resize(numVectors);
+  for (unsigned i = 0; i < op->getNumOperands(); ++i)
+    vectors[i] = {resultType, indexMap, i, false};
+  vectors[numVectors - 1] = {resultType, indexMap, -1, false};
+  resultIndex = numVectors - 1;
+}
+
+// Entry point for unrolling declarative pattern rewrites.
+Value *mlir::vector::unrollSingleResultOpMatchingType(
+    PatternRewriter &builder, Operation *op, ArrayRef<int64_t> targetShape) {
+  assert(op->getNumResults() == 1 && "Expected single result operation");
+
+  // Populate 'iterationBounds', 'vectors' and 'resultIndex' to unroll 'op'.
+  SmallVector<int64_t, 6> iterationBounds;
+  std::vector<VectorState> vectors;
+  unsigned resultIndex;
+
+  if (auto contractionOp = dyn_cast<vector::ContractionOp>(op)) {
+    // Populate state for vector ContractionOp.
+    getVectorContractionOpUnrollState(contractionOp, targetShape,
+                                      iterationBounds, vectors, resultIndex);
+  } else {
+    // Populate state for vector elementwise op.
+    getVectorElementwiseOpUnrollState(op, targetShape, iterationBounds, vectors,
+                                      resultIndex);
+  }
+
+  // Unroll 'op' with 'iterationBounds' to 'targetShape'.
+  return unrollSingleResultStructuredOp(op, iterationBounds, vectors,
+                                        resultIndex, targetShape, builder);
+}

 namespace mlir {
 namespace vector {
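Worked example of the elementwise path: for an addf on vector<4x4xf32>, getVectorElementwiseOpUnrollState produces iterationBounds = {4, 4}, an identity index map, one VectorState per operand (operandIndex 0 and 1) plus one for the result (operandIndex -1), and resultIndex = 2. A standalone mock with the structs reduced to shapes, not the MLIR types:

    #include <cstdint>
    #include <iostream>
    #include <map>
    #include <vector>

    // Mock of the patch's VectorState, with the vector type reduced to a shape.
    struct VectorState {
      std::vector<int64_t> shape;
      std::map<int64_t, int64_t> indexMap;
      int64_t operandIndex = -1;  // -1: not an operand (i.e. the result)
      bool isAcc = false;
    };

    int main() {
      std::vector<int64_t> resultShape = {4, 4};  // addf : vector<4x4xf32>
      unsigned numOperands = 2;

      // Identity index map: iteration dim i <-> vector dim i.
      std::map<int64_t, int64_t> indexMap;
      for (unsigned i = 0; i < resultShape.size(); ++i)
        indexMap[i] = i;

      // One state per operand plus one for the single result.
      std::vector<VectorState> vectors(numOperands + 1);
      for (unsigned i = 0; i < numOperands; ++i)
        vectors[i] = {resultShape, indexMap, static_cast<int64_t>(i), false};
      vectors.back() = {resultShape, indexMap, -1, false};
      unsigned resultIndex = vectors.size() - 1;

      std::cout << "resultIndex = " << resultIndex << "\n";  // 2
      for (auto &v : vectors)
        std::cout << "operandIndex = " << v.operandIndex << "\n";  // 0, 1, -1
    }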
@@ -514,177 +463,9 @@ namespace {
 } // end namespace vector
 } // end namespace mlir

-// Match a fakeFork fed by a fakeJoin and just forward its operands.
-// This is akin to calling `replaceAllUsesOf` but made to play nice with all the
-// other RewritePatterns.
-struct ConvertMatchingFakeForkFakeJoinOp : public RewritePattern {
-  ConvertMatchingFakeForkFakeJoinOp(MLIRContext *context)
-      // low-benefit to kick-in late
-      : RewritePattern(kFakeForkOp, kMatchingFakeForkFakeJoinBenefit, context) {
-  }
-
-  PatternMatchResult matchAndRewrite(Operation *op,
-                                     PatternRewriter &rewriter) const override {
-    if (op->getNumOperands() != 1)
-      return matchFailure();
-
-    auto *definingOp = op->getOperand(0)->getDefiningOp();
-    if (!definingOp || definingOp->getName().getStringRef() != kFakeJoinOp)
-      return matchFailure();
-
-    if (definingOp->getNumOperands() != op->getNumResults())
-      return matchFailure();
-
-    for (auto it : llvm::zip(definingOp->getOperands(), op->getResults())) {
-      if (std::get<0>(it)->getType() != std::get<1>(it)->getType())
-        return matchFailure();
-    }
-
-    LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE
-                         "]: ConvertMatchingFakeForkFakeJoinOp on op: "
-                      << *op << " in func:\n");
-    LLVM_DEBUG(op->getParentOfType<FuncOp>().print(dbgs()));
-    rewriter.replaceOp(op, definingOp->getOperands());
-    return matchSuccess();
-  }
-};
-
-// Rewrites a fakeFork, whose (unique) operand is a blockArgument or a
-// vector.transfer_read, into multiple vector.strided_slice ops.
-struct ConvertFakeForkFromBlockArgsOrTransferReadOp : public RewritePattern {
-  ConvertFakeForkFromBlockArgsOrTransferReadOp(MLIRContext *context)
-      // low-benefit to kick-in late
-      : RewritePattern(kFakeForkOp, 0, context) {}
-
-  PatternMatchResult matchAndRewrite(Operation *op,
-                                     PatternRewriter &rewriter) const override {
-    if (op->getNumOperands() != 1)
-      return matchFailure();
-
-    if (op->use_empty()) {
-      rewriter.eraseOp(op);
-      return matchSuccess();
-    }
-
-    auto *arg = op->getOperand(0);
-    if (!isa<BlockArgument>(arg) &&
-        !isa<vector::TransferReadOp>(arg->getDefiningOp()))
-      return matchFailure();
-
-    LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE
-                         "]: ConvertFakeForkFromBlockArgsOp on op: "
-                      << *op << " in func:\n");
-    LLVM_DEBUG(op->getParentOfType<FuncOp>().print(dbgs()));
-
-    // Look at the unroll factors remaining on this op and act on the first one.
-    auto unrollFactorsStorage = extractUnrollFactors(op);
-    ArrayRef<int64_t> unrollFactors{unrollFactorsStorage};
-    if (unrollFactors.empty()) {
-      // No more unrollFactors, just sanity check + forward the unique operand.
-      assert(op->getNumResults() == 1);
-      assert(arg->getType() == op->getResult(0)->getType());
-      rewriter.replaceOp(op, arg);
-      return matchSuccess();
-    }
-
-    // Strides are always 1 for now.
-    // TODO(b/144845578) support non-1 strides.
-    auto forkedVectorType = arg->getType().cast<VectorType>();
-    SmallVector<int64_t, 4> strides(unrollFactors.size(), 1);
-    auto nUnrolled = computeMaxLinearIndex(unrollFactors);
-    SmallVector<Value *, 4> extractedVectors;
-    extractedVectors.reserve(op->getNumResults());
-    auto linearizationBasis = computeStrides(unrollFactors);
-    for (unsigned idx = 0; idx < nUnrolled; ++idx) {
-      auto offsets = delinearize(idx, linearizationBasis);
-      offsets = zipMap([](int64_t v1, int64_t v2) { return v1 * v2; }, offsets,
-                       unrollFactors);
-      auto leadingSize =
-          forkedVectorType.getShape().take_front(unrollFactors.size());
-      auto sizes = zipMap([](int64_t v1, int64_t v2) { return v1 / v2; },
-                          leadingSize, unrollFactors);
-      extractedVectors.push_back(
-          rewriter
-              .create<vector::StridedSliceOp>(op->getLoc(), arg, offsets, sizes,
-                                              strides)
-              .getResult());
-    }
-    rewriter.replaceOp(op, extractedVectors);
-    return matchSuccess();
-  }
-};
-
-// Rewrites a fakeJoin into a sequence of vector.insert_strided_slice ops that
-// insert its operands into a zero splat of the result vector type.
-struct ConvertFakeJoinOp : public RewritePattern {
-  ConvertFakeJoinOp(MLIRContext *context)
-      // low-benefit to kick-in late
-      : RewritePattern(kFakeJoinOp, 0, context) {}
-
-  PatternMatchResult matchAndRewrite(Operation *op,
-                                     PatternRewriter &rewriter) const override {
-    if (op->getNumResults() != 1)
-      return matchFailure();
-
-    if (op->use_empty()) {
-      rewriter.eraseOp(op);
-      return matchSuccess();
-    }
-
-    auto resultVectorType = op->getResult(0)->getType().cast<VectorType>();
-    auto loc = op->getLoc();
-    auto *res = makeSplatZero(loc, rewriter, resultVectorType);
-
-    auto unrollFactorsStorage = extractUnrollFactors(op);
-    ArrayRef<int64_t> unrollFactors{unrollFactorsStorage};
-    auto linearizationBasis = computeStrides(unrollFactors);
-    auto nUnrolled = computeMaxLinearIndex(unrollFactors);
-    SmallVector<int64_t, 4> strides(unrollFactors.size(), 1);
-    for (unsigned idx = 0; idx < nUnrolled; ++idx) {
-      auto offsets = delinearize(idx, linearizationBasis);
-      offsets = zipMap([](int64_t v1, int64_t v2) { return v1 * v2; }, offsets,
-                       unrollFactors);
-      res = rewriter.create<vector::InsertStridedSliceOp>(
-          loc, op->getOperand(idx), res, offsets, strides);
-    }
-
-    rewriter.replaceOp(op, res);
-    return matchSuccess();
-  }
-};
-
-// Simple DCE for fakeForkOps/fakeJoinOps; we do not want them to escape a
-// transformation (otherwise the transformation is considered incorrect).
-struct FakeForkTrait {
-  static constexpr char const *name = kFakeForkOp;
-};
-struct FakeJoinTrait {
-  static constexpr char const *name = kFakeJoinOp;
-};
-
-template <typename OpNameTrait> struct DCEPattern : public RewritePattern {
-  DCEPattern(MLIRContext *context)
-      // low-benefit to kick-in late
-      : RewritePattern(OpNameTrait::name, 0, context) {}
-
-  PatternMatchResult matchAndRewrite(Operation *op,
-                                     PatternRewriter &rewriter) const override {
-    assert(op->getName().getStringRef() == kFakeForkOp ||
-           op->getName().getStringRef() == kFakeJoinOp);
-    if (!op->use_empty())
-      return matchFailure();
-    rewriter.eraseOp(op);
-    return matchSuccess();
-  }
-};
-
 void mlir::populateVectorToVectorConversionPatterns(
     MLIRContext *context, OwningRewritePatternList &patterns,
     ArrayRef<int64_t> coarseVectorShape, ArrayRef<int64_t> fineVectorShape) {
   vector::populateWithGenerated(context, &patterns);
   vector::populateVectorToVectorCanonicalizationPatterns(patterns, context);
-  patterns
-      .insert<ConvertMatchingFakeForkFakeJoinOp,
-              ConvertFakeForkFromBlockArgsOrTransferReadOp, ConvertFakeJoinOp,
-              DCEPattern<FakeForkTrait>, DCEPattern<FakeJoinTrait>>(context);
 }