From 900bf75eb5b2f77447bcc315a1d76882f72b0f2e Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower"
Date: Mon, 9 Dec 2019 09:34:40 -0800
Subject: [PATCH] Unify vector op unrolling transformation.

Unifies vector op unrolling transformation by using the same unrolling
implementation for contraction and elementwise operations.
Removes fakefork/join operations, which are no longer needed now that we
have the InsertStridedSlice operation.

PiperOrigin-RevId: 284570784
Change-Id: Ic3412bc6456d91bc24afda25bbe98888ff4d5849
---
 .../mlir/Dialect/VectorOps/VectorOps.td       |   1 +
 .../lib/Dialect/VectorOps/VectorToVector.cpp  | 577 ++++++------------
 2 files changed, 180 insertions(+), 398 deletions(-)

diff --git a/third_party/mlir/include/mlir/Dialect/VectorOps/VectorOps.td b/third_party/mlir/include/mlir/Dialect/VectorOps/VectorOps.td
index 172897543e4..2210e8bb923 100644
--- a/third_party/mlir/include/mlir/Dialect/VectorOps/VectorOps.td
+++ b/third_party/mlir/include/mlir/Dialect/VectorOps/VectorOps.td
@@ -165,6 +165,7 @@ def Vector_ContractionOp :
     static StringRef getParallelIteratorTypeName() {
       return "parallel";
     }
+    static unsigned getAccOperandIndex() { return 2; }
 
     // Returns the bounds of each dimension in the iteration space spanned
     // by the iterator types of this operation.
diff --git a/third_party/mlir/lib/Dialect/VectorOps/VectorToVector.cpp b/third_party/mlir/lib/Dialect/VectorOps/VectorToVector.cpp
index 82d19f5efc5..1acac63602c 100644
--- a/third_party/mlir/lib/Dialect/VectorOps/VectorToVector.cpp
+++ b/third_party/mlir/lib/Dialect/VectorOps/VectorToVector.cpp
@@ -102,58 +102,6 @@ static SmallVector<int64_t, 8> delinearize(int64_t linearIndex,
   return res;
 }
 
-static constexpr auto kFakeForkOp = "__fake_fork__";
-static constexpr auto kFakeJoinOp = "__fake_join__";
-static constexpr auto kUnrollAttrName = "__unroll__";
-static constexpr auto kBaseCoordAttrName = "__base_coord__";
-
-// Reads the IntegerArray attribute named `kUnrollAttrName` from `op` and
-// returns its representation as a vector of integers.
-static SmallVector<int64_t, 8> extractUnrollFactors(Operation *op) {
-  SmallVector<int64_t, 8> res;
-  auto unrollAttr = op->getAttr(kUnrollAttrName);
-  if (!unrollAttr)
-    return res;
-  auto unrollArrayAttr = unrollAttr.cast<ArrayAttr>();
-  res.reserve(unrollArrayAttr.size());
-  for (auto attr : unrollArrayAttr) {
-    auto unroll = attr.cast<IntegerAttr>().getValue().getSExtValue();
-    assert(unroll > 0);
-    res.push_back(unroll);
-  }
-  return res;
-}
-
-// Creates a custom `kFakeForkOp` used in progressive lowering to other vector
-// operations.
-static Operation *createFakeForkOp(PatternRewriter &builder, Location loc,
-                                   Value *operand, ArrayRef<Type> resultTypes,
-                                   ArrayRef<int64_t> unrollFactors = {}) {
-  OperationState *forkOp =
-      new OperationState(loc, kFakeForkOp, operand, resultTypes, {});
-  if (!unrollFactors.empty())
-    forkOp->addAttribute(kUnrollAttrName,
-                         builder.getI64ArrayAttr(unrollFactors));
-  return builder.createOperation(*forkOp);
-}
-
-// Creates a custom `kFakeJoinOp` used in progressive lowering to other vector
-// operations.
-static Operation *createFakeJoinOp(PatternRewriter &builder, Location loc,
-                                   ArrayRef<Value *> operands, Type resultType,
-                                   ArrayRef<int64_t> unrollFactors = {},
-                                   ArrayRef<int64_t> baseCoords = {}) {
-  OperationState *joinOp =
-      new OperationState(loc, kFakeJoinOp, operands, resultType, {});
-  if (!unrollFactors.empty())
-    joinOp->addAttribute(kUnrollAttrName,
-                         builder.getI64ArrayAttr(unrollFactors));
-  if (!baseCoords.empty())
-    joinOp->addAttribute(kBaseCoordAttrName,
-                         builder.getI64ArrayAttr(baseCoords));
-  return builder.createOperation(*joinOp);
-}
-
 // Clones `op` into a new operations that takes `operands` and returns
 // `resultTypes`.
 static Operation *cloneOpWithOperandsAndTypes(PatternRewriter &builder,
@@ -202,9 +150,9 @@ static void getMappedElements(const DenseMap<int64_t, int64_t> &indexMap,
   }
 }
-// UnrolledOperandState aggregates per-operand state required for op unrolling.
-struct UnrolledOperandState {
-  Value *operand;
+// UnrolledVectorState aggregates per-operand/result vector state required for
+// unrolling.
+struct UnrolledVectorState {
   SmallVector<int64_t, 4> unrolledShape;
   SmallVector<int64_t, 4> unrollFactors;
   SmallVector<int64_t, 8> basis;
   int64_t numInstances;
@@ -212,14 +160,12 @@
 };
 
 // Populates 'state' with unrolled shape, unroll factors, basis and
-// num unrolled instances for 'operand'.
-static void getUnrolledOperandState(Value *operand,
+// num unrolled instances for 'vectorType'.
+static void initUnrolledVectorState(VectorType vectorType,
                                     const DenseMap<int64_t, int64_t> &indexMap,
                                     ArrayRef<int64_t> targetShape,
-                                    UnrolledOperandState &state) {
-  auto vectorType = operand->getType().cast<VectorType>();
-  state.operand = operand;
-  // Compute unrolled shape of 'operand'.
+                                    UnrolledVectorState &state) {
+  // Compute unrolled shape of 'vectorType'.
   state.unrolledShape.resize(vectorType.getRank());
   getMappedElements(indexMap, targetShape, state.unrolledShape);
   // Compute unroll factors for unrolled shape.
@@ -233,53 +179,72 @@
 }
 
 // Computes and returns the linear index of the unrolled vector at
-// 'vectorOffsets' within the vector operand represented by 'state'.
+// 'vectorOffsets' within the vector represented by 'state'.
 static int64_t
-getUnrolledOperandLinearIndex(UnrolledOperandState &state,
-                              ArrayRef<int64_t> vectorOffsets,
-                              DenseMap<int64_t, int64_t> &indexMap) {
-  // Compute operand offsets.
+getUnrolledVectorLinearIndex(UnrolledVectorState &state,
+                             ArrayRef<int64_t> vectorOffsets,
+                             DenseMap<int64_t, int64_t> &indexMap) {
+  // Compute vector offsets.
   SmallVector<int64_t, 4> sliceOffsets(state.unrolledShape.size());
   getMappedElements(indexMap, vectorOffsets, sliceOffsets);
   // Compute and return linear index of 'sliceOffsets' w.r.t 'state.basis'.
   return linearize(sliceOffsets, state.basis);
 }
 
-// Returns an unrolled vector at 'vectorOffsets' within the vector operand
-// represented by 'state'. The value is created if not present in 'cache'.
-static Value *getOrCreateUnrolledOperandSlice(
-    Location loc, UnrolledOperandState &state, ArrayRef<int64_t> vectorOffsets,
+// Returns an unrolled vector at 'vectorOffsets' within the vector
+// represented by 'state'. The vector is created from a slice of 'initValue'
+// if not present in 'cache'.
+static Value *getOrCreateUnrolledVectorSlice(
+    Location loc, UnrolledVectorState &state, ArrayRef<int64_t> vectorOffsets,
     ArrayRef<int64_t> offsets, DenseMap<int64_t, int64_t> &indexMap,
-    SmallVectorImpl<Value *> &cache, PatternRewriter &builder) {
-  // Compute operand offsets.
+    Value *initValue, SmallVectorImpl<Value *> &cache,
+    PatternRewriter &builder) {
+  // Compute slice offsets.
   SmallVector<int64_t, 4> sliceOffsets(state.unrolledShape.size());
   getMappedElements(indexMap, offsets, sliceOffsets);
   // TODO(b/144845578) Support non-1 strides.
   SmallVector<int64_t, 4> sliceStrides(state.unrolledShape.size(), 1);
   // Compute linear index of 'sliceOffsets' w.r.t 'state.basis'.
   int64_t sliceLinearIndex =
-      getUnrolledOperandLinearIndex(state, vectorOffsets, indexMap);
+      getUnrolledVectorLinearIndex(state, vectorOffsets, indexMap);
   assert(sliceLinearIndex < static_cast<int64_t>(cache.size()));
-  auto *operandSlice = cache[sliceLinearIndex];
-  if (operandSlice == nullptr) {
-    // Initialize 'cache' with slice from 'state.operand'.
-    operandSlice = builder.create<vector::StridedSliceOp>(
-        loc, state.operand, sliceOffsets, state.unrolledShape, sliceStrides);
+  auto *valueSlice = cache[sliceLinearIndex];
+  if (valueSlice == nullptr) {
+    assert(initValue != nullptr);
+    // Initialize 'cache' with slice from 'initValue'.
+    valueSlice = builder.create<vector::StridedSliceOp>(
+        loc, initValue, sliceOffsets, state.unrolledShape, sliceStrides);
     // Store value back to 'cache'.
-    cache[sliceLinearIndex] = operandSlice;
+    cache[sliceLinearIndex] = valueSlice;
   }
-  return operandSlice;
+  return valueSlice;
 }
 
+// VectorState aggregates per-operand/result vector state required for
+// creating slices of vector operands, and clones of the operation being
+// unrolled.
+struct VectorState {
+  // The type of this vector.
+  VectorType type;
+  // Map from iteration space index to vector dimension index.
+  DenseMap<int64_t, int64_t> indexMap;
+  // Index of this value in operation's operand list (-1 if not an operand).
+  int64_t operandIndex = -1;
+  // Accumulator iterator flag.
+  bool isAcc = false;
+};
+
 //
 // unrollSingleResultStructuredOp
 //
 // Returns a value representing the result of structured operation 'op'
 // with iteration bounds 'iterationBounds' unrolled to 'targetShape'.
-// An iteration space index map argument 'iterationIndexMapList' must be
-// specified, with a map for each structured op input and a single map for the
-// single result. The map at index 'indexMapListResultIndex' in the list must
-// be the single result map.
+// A list of VectorState objects must be specified in 'vectors', where
+// each VectorState in the list represents a vector operand or vector result
+// (if the operation does not have an accumulator operand).
+// The VectorState at index 'resultIndex' in the list must be the state
+// associated with the operation's single result (i.e. either its accumulator
+// operand or vector result value).
 //
 // Example:
 //
@@ -304,13 +269,24 @@ static Value *getOrCreateUnrolledOperandSlice(
 //   insertslice
 //       |
 
+// TODO(andydavis) Add the following canonicalization/simplification patterns:
+// *) Add pattern which matches InsertStridedSlice -> StridedSlice and forwards
+//    the InsertStridedSlice operand to the StridedSlice.
+// *) Add pattern which matches SourceOp -> StridedSlice -> UserOp which checks
+//    if there are duplicate identical StridedSlice ops from SourceOp, and
+//    rewrites itself to use the first duplicate. This transformation should
+//    cause users of identical StridedSlice ops to reuse the same StridedSlice
+//    operation, and leave the duplicate StridedSlice ops with no users
+//    (removable with DCE).
+
 // TODO(andydavis) Generalize this to support structured ops beyond
 // vector ContractionOp, and merge it with 'unrollSingleResultOpMatchingType'
-static Value *unrollSingleResultStructuredOp(
-    Operation *op, ArrayRef<int64_t> iterationBounds,
-    std::vector<DenseMap<int64_t, int64_t>> &iterationIndexMapList,
-    unsigned indexMapListResultIndex, ArrayRef<int64_t> targetShape,
-    PatternRewriter &builder) {
+static Value *unrollSingleResultStructuredOp(Operation *op,
+                                             ArrayRef<int64_t> iterationBounds,
+                                             std::vector<VectorState> &vectors,
+                                             unsigned resultIndex,
+                                             ArrayRef<int64_t> targetShape,
+                                             PatternRewriter &builder) {
   auto shapedType = op->getResult(0)->getType().dyn_cast_or_null<ShapedType>();
   if (!shapedType || !shapedType.hasStaticShape())
     assert(false && "Expected a statically shaped result type");
@@ -321,27 +297,25 @@
     assert(false && "Failed to compute unroll factors for target shape");
   auto unrollFactors = *maybeUnrollFactors;
 
-  // Compute unrolled operation state for each mapped operand.
-  unsigned numMaps = iterationIndexMapList.size();
-  SmallVector<UnrolledOperandState, 4> unrolledOperandState(numMaps);
-  assert(op->getNumOperands() >= numMaps);
-  for (unsigned i = 0; i < numMaps; ++i) {
-    getUnrolledOperandState(op->getOperand(i), iterationIndexMapList[i],
-                            targetShape, unrolledOperandState[i]);
+  // Compute unrolled vector state for each vector in 'vectors'.
+  unsigned numVectors = vectors.size();
+  SmallVector<UnrolledVectorState, 4> unrolledVectorState(numVectors);
+  for (unsigned i = 0; i < numVectors; ++i) {
+    initUnrolledVectorState(vectors[i].type, vectors[i].indexMap, targetShape,
+                            unrolledVectorState[i]);
   }
   // Compute number of total unrolled instances.
   auto numUnrolledInstances = computeMaxLinearIndex(unrollFactors);
   auto basis = computeStrides(unrollFactors);
 
-  auto &resultOperandState = unrolledOperandState[indexMapListResultIndex];
-  auto unrolledResultType = VectorType::get(resultOperandState.unrolledShape,
+  auto &resultValueState = unrolledVectorState[resultIndex];
+  auto unrolledResultType = VectorType::get(resultValueState.unrolledShape,
                                             shapedType.getElementType());
 
   // Initialize caches for intermediate vector results.
-  std::vector<SmallVector<Value *, 4>> caches(numMaps);
-  for (unsigned i = 0; i < numMaps; ++i) {
-    caches[i].resize(unrolledOperandState[i].numInstances);
-  }
+  std::vector<SmallVector<Value *, 4>> caches(numVectors);
+  for (unsigned i = 0; i < numVectors; ++i)
+    caches[i].resize(unrolledVectorState[i].numInstances);
 
   // Unroll 'numUnrolledInstances' of 'op', storing results in 'caches'.
   for (unsigned i = 0; i < numUnrolledInstances; ++i) {
@@ -352,11 +326,15 @@
         vectorOffsets, targetShape);
     // Get cached slice (or create slice) for each operand at 'offsets'.
     SmallVector<Value *, 3> operands;
-    operands.reserve(numMaps);
-    for (unsigned i = 0; i < numMaps; ++i) {
-      operands.push_back(getOrCreateUnrolledOperandSlice(
-          op->getLoc(), unrolledOperandState[i], vectorOffsets, offsets,
-          iterationIndexMapList[i], caches[i], builder));
+    operands.resize(op->getNumOperands());
+    for (unsigned i = 0; i < numVectors; ++i) {
+      int64_t operandIndex = vectors[i].operandIndex;
+      if (operandIndex < 0)
+        continue; // Output
+      auto *operand = op->getOperand(operandIndex);
+      operands[operandIndex] = getOrCreateUnrolledVectorSlice(
+          op->getLoc(), unrolledVectorState[i], vectorOffsets, offsets,
+          vectors[i].indexMap, operand, caches[i], builder);
     }
     // Create op on sliced vector arguments.
     auto resultVector =
@@ -365,146 +343,117 @@ static Value *unrollSingleResultStructuredOp(
             ->getResult(0);
     // Compute linear result index.
-    int64_t resultIndex = getUnrolledOperandLinearIndex(
-        resultOperandState, vectorOffsets,
-        iterationIndexMapList[indexMapListResultIndex]);
-    // Update result cache at 'resultIndex'.
-    caches[indexMapListResultIndex][resultIndex] = resultVector;
+    int64_t linearIndex = getUnrolledVectorLinearIndex(
+        resultValueState, vectorOffsets, vectors[resultIndex].indexMap);
+    // Update result cache at 'linearIndex'.
+    caches[resultIndex][linearIndex] = resultVector;
   }
 
   // Make zero splat into which we will insert results from
-  // 'cache[indexMapListResultIndex]'
+  // 'cache[resultIndex]'
   auto resultVectorType = op->getResult(0)->getType().cast<VectorType>();
   auto *res = makeSplatZero(op->getLoc(), builder, resultVectorType);
-  SmallVector<int64_t, 4> strides(resultOperandState.unrollFactors.size(), 1);
+  SmallVector<int64_t, 4> strides(resultValueState.unrollFactors.size(), 1);
   // Insert vector accumulators into output.
-  for (unsigned i = 0; i < resultOperandState.numInstances; ++i) {
-    auto vectorOffsets = delinearize(i, resultOperandState.basis);
+  for (unsigned i = 0; i < resultValueState.numInstances; ++i) {
+    auto vectorOffsets = delinearize(i, resultValueState.basis);
     // Convert from unrolled vector-space offsets to element-space offsets.
     auto offsets = zipMap([](int64_t v1, int64_t v2) { return v1 * v2; },
-                          vectorOffsets, resultOperandState.unrolledShape);
+                          vectorOffsets, resultValueState.unrolledShape);
     res = builder.create<vector::InsertStridedSliceOp>(
-        op->getLoc(), caches[indexMapListResultIndex][i], res, offsets,
-        strides);
+        op->getLoc(), caches[resultIndex][i], res, offsets, strides);
   }
-
   return res;
 }
 
-// Entry point for unrolling declarative pattern rewrites.
-// `op` is unrolled to the `targetShape` as follows, for each of its operands:
-//   1. the unrolled type `unrolledVectorType` and number of unrolled instances
-//   `numUnrolledInstances` are computed from the `targetShape`. For now it is
-//   assumed the unrolling factors divide the vector sizes.
-//   2. a fakeFork cast op is inserted that takes the operand and returns
-//   `numUnrolledInstances` results of type `unrolledVectorType`.
-//   3. the original op is cloned `numUnrolledInstances` times, once for each
-//   result of the fakeFork cast op.
-//   4. a fakeJoin cast op takes all these results and merges them into a single
-//   aggregate vector result whose size matches the original non-unrolled op
-//   operand types.
-//
-// Example:
-//
-//   opA(operand0, operand1)  // numUnrolledInstances = 3
-//
-//            operand0                   operand1
-//               |                          |
-//             fork                       fork
-//        <----------gather all fork ops --------->
-//              /|\                        /|\
-//          f00 f01 f02               f10 f11 f12
-//        <---------- clone op 3 times --------->
-//          opA0(f00, f10), opA1(f01, f11), opA2(f02, f12)
-//                 \            |            /
-//        <-------------------- join ------------------------->
-//
-// Other local patterns then kick in iteratively (including DCE) and compose
-// until all the fakeFork and fakeJoin ops are removed.
-//
-// This will be extended in the future to support more advanced use cases than
-// simple pointwise ops.
-Value * mlir::vector::unrollSingleResultOpMatchingType(PatternRewriter &builder,
-                                                       Operation *op,
-                                                       ArrayRef<int64_t> targetShape) {
-  if (auto contractionOp = dyn_cast<vector::ContractionOp>(op)) {
-    // Get contraction op iteration bounds.
-    SmallVector<int64_t, 6> iterationBounds;
-    contractionOp.getIterationBounds(iterationBounds);
-    assert(iterationBounds.size() == targetShape.size());
-    // Get map from iteration space index to lhs/rhs/result shape index.
-    std::vector<DenseMap<int64_t, int64_t>> iterationIndexMapList;
-    contractionOp.getIterationIndexMap(iterationIndexMapList);
-    if (llvm::size(contractionOp.masks()) == 2) {
-      // Add maps for lhs/rhs vector mask arguments (same lhs/rhs vector shape)
-      iterationIndexMapList.push_back(iterationIndexMapList[0]);
-      iterationIndexMapList.push_back(iterationIndexMapList[1]);
-    }
-    // Unroll 'op' 'iterationBounds' to 'targetShape'.
-    // TODO(andydavis) Use linalg style 'args_in'/'args_out' to partition
-    // 'iterationIndexMapList' instead of 'indexMapListResultIndex'.
-    return unrollSingleResultStructuredOp(
-        op, iterationBounds, iterationIndexMapList,
-        /*indexMapListResultIndex=*/2, targetShape, builder);
-  }
-  // TODO(andydavis) Create trivial iteration bounds and index map for
-  // elementwise operations and call 'unrollSingleResultStructuredOp'. Remove
-  // fakefork/join if possible.
-
-  LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE
-                       "]: unrollSingleResultOpMatchingType on func:\n");
-  LLVM_DEBUG(op->getParentOfType<FuncOp>().print(dbgs()));
-  if (!op->getNumResults())
-    assert(false && "Use precondition till RewriterGen can act on nullptr");
-
-  auto shapedType = op->getResult(0)->getType().dyn_cast_or_null<ShapedType>();
-  if (!shapedType || !shapedType.hasStaticShape())
-    assert(false && "Use precondition till RewriterGen can act on nullptr");
-
-  auto shape = shapedType.getShape();
-  auto maybeUnrollFactors = shapeRatio(shape, targetShape);
-  if (!maybeUnrollFactors.hasValue())
-    assert(false && "Use precondition till RewriterGen can act on nullptr");
-  auto unrollFactors = *maybeUnrollFactors;
-
-  auto loc = op->getLoc();
-  auto numUnrolledInstances = computeMaxLinearIndex(unrollFactors);
-  auto unrolledVectorType =
-      VectorType::get(targetShape, shapedType.getElementType());
-  SmallVector<Type, 4> forkedType(numUnrolledInstances, unrolledVectorType);
-  SmallVector<Operation *, 4> forkeds;
-  forkeds.reserve(numUnrolledInstances);
-  // Create a new forkOp for each operand.
-  for (auto *operand : op->getOperands())
-    forkeds.push_back(
-        createFakeForkOp(builder, loc, operand, forkedType, unrollFactors));
-
-  SmallVector<Operation *, 4> newOps;
-  newOps.reserve(numUnrolledInstances);
-  for (int64_t idx = 0; idx < numUnrolledInstances; ++idx) {
-    SmallVector<Value *, 4> operands;
-    operands.reserve(forkeds.size());
-    for (auto *fork : forkeds) {
-      operands.push_back(fork->getResult(idx));
-    }
-    newOps.push_back(cloneOpWithOperandsAndTypes(builder, loc, op, operands,
-                                                 unrolledVectorType));
+static void getVectorContractionOpUnrollState(
+    vector::ContractionOp contractionOp, ArrayRef<int64_t> targetShape,
+    SmallVectorImpl<int64_t> &iterationBounds,
+    std::vector<VectorState> &vectors, unsigned &resultIndex) {
+  // Get contraction op iteration bounds.
+  contractionOp.getIterationBounds(iterationBounds);
+  assert(iterationBounds.size() == targetShape.size());
+  // Get map from iteration space index to lhs/rhs/result shape index.
+  std::vector<DenseMap<int64_t, int64_t>> iterationIndexMapList;
+  contractionOp.getIterationIndexMap(iterationIndexMapList);
+  unsigned numIterators = iterationIndexMapList.size();
+  vectors.resize(numIterators);
+  unsigned accOperandIndex = vector::ContractionOp::getAccOperandIndex();
+  for (unsigned i = 0; i < numIterators; ++i) {
+    vectors[i].type = contractionOp.getOperand(i)->getType().cast<VectorType>();
+    vectors[i].indexMap = iterationIndexMapList[i];
+    vectors[i].operandIndex = i;
+    vectors[i].isAcc = i == accOperandIndex ? true : false;
   }
-  SmallVector<Value *, 4> newOpResults;
-  newOpResults.reserve(newOps.size());
-  for (auto *newOp : newOps)
-    newOpResults.push_back(newOp->getResult(0));
-
-  return createFakeJoinOp(builder, loc, newOpResults, shapedType, unrollFactors,
-                          {0})
-      ->getResult(0);
+  if (llvm::size(contractionOp.masks()) == 2) {
+    // Add vectors for lhs/rhs vector mask arguments. Masks have the same
+    // vector shape as the lhs/rhs args, so copy their index maps.
+    vectors.push_back(
+        {vectors[0].type, vectors[0].indexMap, accOperandIndex + 1, false});
+    vectors.push_back(
+        {vectors[1].type, vectors[1].indexMap, accOperandIndex + 2, false});
+  }
+  // Unroll 'op' 'iterationBounds' to 'targetShape'.
+  // TODO(andydavis) Use linalg style 'args_in'/'args_out' to partition
+  // 'vectors' instead of 'resultIndex'.
+  resultIndex = accOperandIndex;
 }
 
-// Patterns with this benefit just forwards arguments to clean up fake fork and
-// fake joins. It is a nicer and more direct cleanup when we can use it so it
-// kicks in with higher precedence.
-static constexpr int64_t kMatchingFakeForkFakeJoinBenefit = 1;
+static void
+getVectorElementwiseOpUnrollState(Operation *op, ArrayRef<int64_t> targetShape,
+                                  SmallVectorImpl<int64_t> &iterationBounds,
+                                  std::vector<VectorState> &vectors,
+                                  unsigned &resultIndex) {
+  // Verify that operation and operands all have the same vector shape.
+  auto resultType = op->getResult(0)->getType().dyn_cast_or_null<VectorType>();
+  assert(resultType && "Expected op with vector result type");
+  auto resultShape = resultType.getShape();
+  // Verify that all operands have the same vector type as result.
+  assert(llvm::all_of(op->getOperandTypes(),
+                      [=](Type type) { return type == resultType; }));
+  // Populate 'iterationBounds' with 'resultShape' for elementwise operations.
+  iterationBounds.assign(resultShape.begin(), resultShape.end());
+
+  // Create trivial elementwise identity index map based on 'resultShape'.
+  DenseMap<int64_t, int64_t> indexMap;
+  indexMap.reserve(resultShape.size());
+  for (unsigned i = 0; i < resultShape.size(); ++i)
+    indexMap[i] = i;
+
+  // Create a VectorState for each operand and for the single result.
+  unsigned numVectors = op->getNumOperands() + op->getNumResults();
+  vectors.resize(numVectors);
+  for (unsigned i = 0; i < op->getNumOperands(); ++i)
+    vectors[i] = {resultType, indexMap, i, false};
+  vectors[numVectors - 1] = {resultType, indexMap, -1, false};
+  resultIndex = numVectors - 1;
+}
+
+// Entry point for unrolling declarative pattern rewrites.
+Value *mlir::vector::unrollSingleResultOpMatchingType(
+    PatternRewriter &builder, Operation *op, ArrayRef<int64_t> targetShape) {
+  assert(op->getNumResults() == 1 && "Expected single result operation");
+
+  // Populate 'iterationBounds', 'vectors' and 'resultIndex' to unroll 'op'.
+  SmallVector<int64_t, 6> iterationBounds;
+  std::vector<VectorState> vectors;
+  unsigned resultIndex;
+
+  if (auto contractionOp = dyn_cast<vector::ContractionOp>(op)) {
+    // Populate state for vector ContractionOp.
+    getVectorContractionOpUnrollState(contractionOp, targetShape,
+                                      iterationBounds, vectors, resultIndex);
+  } else {
+    // Populate state for vector elementwise op.
+    getVectorElementwiseOpUnrollState(op, targetShape, iterationBounds,
+                                      vectors, resultIndex);
+  }
+
+  // Unroll 'op' with 'iterationBounds' to 'targetShape'.
+  return unrollSingleResultStructuredOp(op, iterationBounds, vectors,
+                                        resultIndex, targetShape, builder);
+}
 
 namespace mlir {
 namespace vector {
 namespace {
@@ -514,177 +463,9 @@
 } // end namespace vector
 } // end namespace mlir
 
-// Match a fakeFork fed by a fakeJoin and just forward its operands.
-// This is akin to calling `replaceAllUsesOf` but made to play nice with all the
-// other RewritePattern.
-struct ConvertMatchingFakeForkFakeJoinOp : public RewritePattern {
-  ConvertMatchingFakeForkFakeJoinOp(MLIRContext *context)
-      // low-benefit to kick-in late
-      : RewritePattern(kFakeForkOp, kMatchingFakeForkFakeJoinBenefit, context) {
-  }
-
-  PatternMatchResult matchAndRewrite(Operation *op,
-                                     PatternRewriter &rewriter) const override {
-    if (op->getNumOperands() != 1)
-      return matchFailure();
-
-    auto *definingOp = op->getOperand(0)->getDefiningOp();
-    if (!definingOp || definingOp->getName().getStringRef() != kFakeJoinOp)
-      return matchFailure();
-
-    if (definingOp->getNumOperands() != op->getNumResults())
-      return matchFailure();
-
-    for (auto it : llvm::zip(definingOp->getOperands(), op->getResults())) {
-      if (std::get<0>(it)->getType() != std::get<1>(it)->getType())
-        return matchFailure();
-    }
-
-    LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE
-                         "]: ConvertMatchingFakeForkFakeJoinOp on op: "
-                      << *op << " in func:\n");
-    LLVM_DEBUG(op->getParentOfType<FuncOp>().print(dbgs()));
-    rewriter.replaceOp(op, definingOp->getOperands());
-    return matchSuccess();
-  }
-};
-
-// Rewrites a fakeFork, whose (unique) operand is a blockArgument, into multiple
-// vector.strided_slice ops.
-struct ConvertFakeForkFromBlockArgsOrTransferReadOp : public RewritePattern {
-  ConvertFakeForkFromBlockArgsOrTransferReadOp(MLIRContext *context)
-      // low-benefit to kick-in late
-      : RewritePattern(kFakeForkOp, 0, context) {}
-
-  PatternMatchResult matchAndRewrite(Operation *op,
-                                     PatternRewriter &rewriter) const override {
-    if (op->getNumOperands() != 1)
-      return matchFailure();
-
-    if (op->use_empty()) {
-      rewriter.eraseOp(op);
-      return matchSuccess();
-    }
-
-    auto *arg = op->getOperand(0);
-    if (!isa<BlockArgument>(arg) &&
-        !isa<vector::TransferReadOp>(arg->getDefiningOp()))
-      return matchFailure();
-
-    LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE
-                         "]: ConvertFakeForkFromBlockArgsOp on op: "
-                      << *op << " in func:\n");
-    LLVM_DEBUG(op->getParentOfType<FuncOp>().print(dbgs()));
-
-    // Look at the unroll factors remaining on this op and act on the first one.
-    auto unrollFactorsStorage = extractUnrollFactors(op);
-    ArrayRef<int64_t> unrollFactors{unrollFactorsStorage};
-    if (unrollFactors.empty()) {
-      // No more unrollFactors, just sanity check + forward the unique operand.
-      assert(op->getNumResults() == 1);
-      assert(arg->getType() == op->getResult(0)->getType());
-      rewriter.replaceOp(op, arg);
-      return matchSuccess();
-    }
-
-    // Strides are always 1 for now.
-    // TODO(b/144845578) support non-1 strides.
-    auto forkedVectorType = arg->getType().cast<VectorType>();
-    SmallVector<int64_t, 4> strides(unrollFactors.size(), 1);
-    auto nUnrolled = computeMaxLinearIndex(unrollFactors);
-    SmallVector<Value *, 4> extractedVectors;
-    extractedVectors.reserve(op->getNumResults());
-    auto linearizationBasis = computeStrides(unrollFactors);
-    for (unsigned idx = 0; idx < nUnrolled; ++idx) {
-      auto offsets = delinearize(idx, linearizationBasis);
-      offsets = zipMap([](int64_t v1, int64_t v2) { return v1 * v2; }, offsets,
-                       unrollFactors);
-      auto leadingSize =
-          forkedVectorType.getShape().take_front(unrollFactors.size());
-      auto sizes = zipMap([](int64_t v1, int64_t v2) { return v1 / v2; },
-                          leadingSize, unrollFactors);
-      extractedVectors.push_back(
-          rewriter
-              .create<vector::StridedSliceOp>(op->getLoc(), arg, offsets, sizes,
-                                              strides)
-              .getResult());
-    }
-    rewriter.replaceOp(op, extractedVectors);
-    return matchSuccess();
-  }
-};
-
-// Rewrites a fakeJoin, whose (unique) operand is a blockArgument, into multiple
-// vector.strided_slice ops.
-struct ConvertFakeJoinOp : public RewritePattern {
-  ConvertFakeJoinOp(MLIRContext *context)
-      // low-benefit to kick-in late
-      : RewritePattern(kFakeJoinOp, 0, context) {}
-
-  PatternMatchResult matchAndRewrite(Operation *op,
-                                     PatternRewriter &rewriter) const override {
-    if (op->getNumResults() != 1)
-      return matchFailure();
-
-    if (op->use_empty()) {
-      rewriter.eraseOp(op);
-      return matchSuccess();
-    }
-
-    auto resultVectorType = op->getResult(0)->getType().cast<VectorType>();
-    auto loc = op->getLoc();
-    auto *res = makeSplatZero(loc, rewriter, resultVectorType);
-
-    auto unrollFactorsStorage = extractUnrollFactors(op);
-    ArrayRef<int64_t> unrollFactors{unrollFactorsStorage};
-    auto linearizationBasis = computeStrides(unrollFactors);
-    auto nUnrolled = computeMaxLinearIndex(unrollFactors);
-    SmallVector<int64_t, 4> strides(unrollFactors.size(), 1);
-    for (unsigned idx = 0; idx < nUnrolled; ++idx) {
-      auto offsets = delinearize(idx, linearizationBasis);
-      offsets = zipMap([](int64_t v1, int64_t v2) { return v1 * v2; }, offsets,
-                       unrollFactors);
-      res = rewriter.create<vector::InsertStridedSliceOp>(
-          loc, op->getOperand(idx), res, offsets, strides);
-    }
-
-    rewriter.replaceOp(op, res);
-    return matchSuccess();
-  }
-};
-
-// Simple DCE for fakeForkOps/fakeJoinOps, we do not want them to escape a
-// transformation (otherwise the transformation is considered incorrect).
-struct FakeForkTrait {
-  static constexpr char const *name = kFakeForkOp;
-};
-struct FakeJoinTrait {
-  static constexpr char const *name = kFakeJoinOp;
-};
-
-template <typename OpNameTrait> struct DCEPattern : public RewritePattern {
-  DCEPattern(MLIRContext *context)
-      // low-benefit to kick-in late
-      : RewritePattern(OpNameTrait::name, 0, context) {}
-
-  PatternMatchResult matchAndRewrite(Operation *op,
-                                     PatternRewriter &rewriter) const override {
-    assert(op->getName().getStringRef() == kFakeForkOp ||
-           op->getName().getStringRef() == kFakeJoinOp);
-    if (!op->use_empty())
-      return matchFailure();
-    rewriter.eraseOp(op);
-    return matchSuccess();
-  }
-};
-
 void mlir::populateVectorToVectorConversionPatterns(
     MLIRContext *context, OwningRewritePatternList &patterns,
     ArrayRef<int64_t> coarseVectorShape, ArrayRef<int64_t> fineVectorShape) {
   vector::populateWithGenerated(context, &patterns);
   vector::populateVectorToVectorCanonicalizationPatterns(patterns, context);
-  patterns
-      .insert<ConvertMatchingFakeForkFakeJoinOp,
-              ConvertFakeForkFromBlockArgsOrTransferReadOp, ConvertFakeJoinOp,
-              DCEPattern<FakeForkTrait>, DCEPattern<FakeJoinTrait>>(context);
 }
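---

Note (illustration appended for this write-up, not part of the applied patch):
the unified unrolling above is driven by a small amount of index arithmetic.
Per-dimension unroll factors are the ratio of the original vector shape to
'targetShape'; each unrolled instance index is delinearized into vector-space
offsets; and those offsets are scaled by the slice shape to obtain the
element-space offsets fed to StridedSlice and InsertStridedSlice. The
standalone C++ sketch below demonstrates only that arithmetic. The helper
names mirror the shapeRatio/computeStrides/delinearize utilities in
VectorToVector.cpp, but the bodies are simplified stand-ins written for this
note, not the MLIR implementations.

#include <cassert>
#include <cstdint>
#include <iostream>
#include <vector>

// Per-dimension ratio of the original shape to the target (slice) shape.
static std::vector<int64_t> shapeRatio(const std::vector<int64_t> &shape,
                                       const std::vector<int64_t> &target) {
  assert(shape.size() == target.size());
  std::vector<int64_t> ratio(shape.size());
  for (std::size_t i = 0; i < shape.size(); ++i) {
    assert(shape[i] % target[i] == 0 && "unroll factors must divide the shape");
    ratio[i] = shape[i] / target[i];
  }
  return ratio;
}

// Row-major strides ("basis") used to (de)linearize instance indices.
static std::vector<int64_t> computeStrides(const std::vector<int64_t> &factors) {
  std::vector<int64_t> basis(factors.size(), 1);
  for (int i = static_cast<int>(factors.size()) - 2; i >= 0; --i)
    basis[i] = basis[i + 1] * factors[i + 1];
  return basis;
}

// Splits a linear instance index into per-dimension vector-space offsets.
static std::vector<int64_t> delinearize(int64_t linearIndex,
                                        const std::vector<int64_t> &basis) {
  std::vector<int64_t> offsets(basis.size());
  for (std::size_t i = 0; i < basis.size(); ++i) {
    offsets[i] = linearIndex / basis[i];
    linearIndex %= basis[i];
  }
  return offsets;
}

int main() {
  // Unroll a <4x4> op into <2x2> slices: 2*2 = 4 unrolled instances.
  std::vector<int64_t> shape = {4, 4}, targetShape = {2, 2};
  auto unrollFactors = shapeRatio(shape, targetShape);
  auto basis = computeStrides(unrollFactors);
  int64_t numInstances = 1;
  for (int64_t f : unrollFactors)
    numInstances *= f;
  for (int64_t i = 0; i < numInstances; ++i) {
    auto vectorOffsets = delinearize(i, basis);
    // Element-space offsets = vector-space offsets * slice shape (the zipMap
    // step in the patch); these feed StridedSlice/InsertStridedSlice.
    std::cout << "instance " << i << ": offsets = {"
              << vectorOffsets[0] * targetShape[0] << ", "
              << vectorOffsets[1] * targetShape[1] << "}\n";
  }
  return 0;
}

For the <4x4>-to-<2x2> case this prints the four slice offsets {0, 0}, {0, 2},
{2, 0} and {2, 2}, i.e. the slices the unrolled operation materializes and
then reassembles with InsertStridedSlice.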