Unify vector op unrolling transformation.

Unifies the vector op unrolling transformation by using the same unrolling implementation for contraction and elementwise operations.
Removes the fake fork/join operations, which are no longer needed now that we have the InsertStridedSlice operation.

PiperOrigin-RevId: 284570784
Change-Id: Ic3412bc6456d91bc24afda25bbe98888ff4d5849
A. Unique TensorFlower authored on 2019-12-09 09:34:40 -08:00; committed by TensorFlower Gardener
parent 27a98ccab4
commit 900bf75eb5
2 changed files with 180 additions and 398 deletions


@@ -165,6 +165,7 @@ def Vector_ContractionOp :
static StringRef getParallelIteratorTypeName() {
return "parallel";
}
static unsigned getAccOperandIndex() { return 2; }
// Returns the bounds of each dimension in the iteration space spanned
// by the iterator types of this operation.


@@ -102,58 +102,6 @@ static SmallVector<int64_t, 8> delinearize(int64_t linearIndex,
return res;
}
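Aside: the unrolling code below leans heavily on the linearize/delinearize/computeStrides helpers. As a point of reference, here is a minimal standalone sketch of that index arithmetic using only std types (names with a 'Demo' suffix are hypothetical; the real helpers live in this file and the vector utilities):

#include <cassert>
#include <cstdint>
#include <vector>

// Row-major strides of `shape` (suffix products), cf. computeStrides.
static std::vector<int64_t> computeStridesDemo(const std::vector<int64_t> &shape) {
  std::vector<int64_t> strides(shape.size(), 1);
  for (int64_t i = static_cast<int64_t>(shape.size()) - 2; i >= 0; --i)
    strides[i] = strides[i + 1] * shape[i + 1];
  return strides;
}

// Dot product of per-dimension offsets with strides, cf. linearize.
static int64_t linearizeDemo(const std::vector<int64_t> &offsets,
                             const std::vector<int64_t> &basis) {
  int64_t linear = 0;
  for (size_t i = 0; i < offsets.size(); ++i)
    linear += offsets[i] * basis[i];
  return linear;
}

// Inverse of linearizeDemo: recover per-dimension offsets, cf. delinearize.
static std::vector<int64_t> delinearizeDemo(int64_t linearIndex,
                                            const std::vector<int64_t> &basis) {
  std::vector<int64_t> res(basis.size());
  for (size_t i = 0; i < basis.size(); ++i) {
    res[i] = linearIndex / basis[i];
    linearIndex %= basis[i];
  }
  return res;
}

int main() {
  // Unrolling vector<4x8xf32> to a 2x4 target gives unroll factors {2, 2}.
  std::vector<int64_t> unrollFactors = {2, 2};
  auto basis = computeStridesDemo(unrollFactors); // {2, 1}
  auto offsets = delinearizeDemo(3, basis);       // instance 3 -> {1, 1}
  assert(offsets[0] == 1 && offsets[1] == 1);
  assert(linearizeDemo(offsets, basis) == 3);     // round-trips
  return 0;
}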
static constexpr auto kFakeForkOp = "__fake_fork__";
static constexpr auto kFakeJoinOp = "__fake_join__";
static constexpr auto kUnrollAttrName = "__unroll__";
static constexpr auto kBaseCoordAttrName = "__base_coord__";
// Reads the IntegerArray attribute named `kUnrollAttrName` from `op` and
// returns its representation as a vector of integers.
static SmallVector<int64_t, 8> extractUnrollFactors(Operation *op) {
SmallVector<int64_t, 8> res;
auto unrollAttr = op->getAttr(kUnrollAttrName);
if (!unrollAttr)
return res;
auto unrollArrayAttr = unrollAttr.cast<ArrayAttr>();
res.reserve(unrollArrayAttr.size());
for (auto attr : unrollArrayAttr) {
auto unroll = attr.cast<IntegerAttr>().getValue().getSExtValue();
assert(unroll > 0);
res.push_back(unroll);
}
return res;
}
// Creates a custom `kFakeForkOp` used in progressive lowering to other vector
// operations.
static Operation *createFakeForkOp(PatternRewriter &builder, Location loc,
Value *operand, ArrayRef<Type> resultTypes,
ArrayRef<int64_t> unrollFactors = {}) {
OperationState *forkOp =
new OperationState(loc, kFakeForkOp, operand, resultTypes, {});
if (!unrollFactors.empty())
forkOp->addAttribute(kUnrollAttrName,
builder.getI64ArrayAttr(unrollFactors));
return builder.createOperation(*forkOp);
}
// Creates a custom `kFakeJoinOp` used in progressive lowering to other vector
// operations.
static Operation *createFakeJoinOp(PatternRewriter &builder, Location loc,
ArrayRef<Value *> operands, Type resultType,
ArrayRef<int64_t> unrollFactors = {},
ArrayRef<int64_t> baseCoords = {}) {
OperationState *joinOp =
new OperationState(loc, kFakeJoinOp, operands, resultType, {});
if (!unrollFactors.empty())
joinOp->addAttribute(kUnrollAttrName,
builder.getI64ArrayAttr(unrollFactors));
if (!baseCoords.empty())
joinOp->addAttribute(kBaseCoordAttrName,
builder.getI64ArrayAttr(baseCoords));
return builder.createOperation(*joinOp);
}
// Clones `op` into a new operation that takes `operands` and returns
// `resultTypes`.
static Operation *cloneOpWithOperandsAndTypes(PatternRewriter &builder,
@@ -202,9 +150,9 @@ static void getMappedElements(const DenseMap<int64_t, int64_t> &indexMap,
}
}
// UnrolledOperandState aggregates per-operand state required for op unrolling.
struct UnrolledOperandState {
Value *operand;
// UnrolledVectorState aggregates per-operand/result vector state required for
// unrolling.
struct UnrolledVectorState {
SmallVector<int64_t, 4> unrolledShape;
SmallVector<int64_t, 4> unrollFactors;
SmallVector<int64_t, 8> basis;
@@ -212,14 +160,12 @@ struct UnrolledOperandState {
};
// Populates 'state' with unrolled shape, unroll factors, basis and
// num unrolled instances for 'operand'.
static void getUnrolledOperandState(Value *operand,
// num unrolled instances for 'vectorType'.
static void initUnrolledVectorState(VectorType vectorType,
const DenseMap<int64_t, int64_t> &indexMap,
ArrayRef<int64_t> targetShape,
UnrolledOperandState &state) {
auto vectorType = operand->getType().cast<VectorType>();
state.operand = operand;
// Compute unrolled shape of 'operand'.
UnrolledVectorState &state) {
// Compute unrolled shape of 'vectorType'.
state.unrolledShape.resize(vectorType.getRank());
getMappedElements(indexMap, targetShape, state.unrolledShape);
// Compute unroll factors for unrolled shape.
@@ -233,53 +179,72 @@ static void getUnrolledOperandState(Value *operand,
}
// Computes and returns the linear index of the unrolled vector at
// 'vectorOffsets' within the vector operand represented by 'state'.
// 'vectorOffsets' within the vector represented by 'state'.
static int64_t
getUnrolledOperandLinearIndex(UnrolledOperandState &state,
getUnrolledVectorLinearIndex(UnrolledVectorState &state,
ArrayRef<int64_t> vectorOffsets,
DenseMap<int64_t, int64_t> &indexMap) {
// Compute operand offsets.
// Compute vector offsets.
SmallVector<int64_t, 4> sliceOffsets(state.unrolledShape.size());
getMappedElements(indexMap, vectorOffsets, sliceOffsets);
// Compute and return linear index of 'sliceOffsets' w.r.t 'state.basis'.
return linearize(sliceOffsets, state.basis);
}
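A worked example with hypothetical matmul-style values: for an (i, j, k) contraction, the lhs index map sends iteration dims {i, k} to the lhs vector dims {0, 1}. Iteration-space offsets are first projected through that map (cf. getMappedElements), then linearized against the lhs basis:

#include <cstdint>
#include <map>
#include <vector>

int main() {
  // Hypothetical lhs of an (i=0, j=1, k=2) contraction, unroll factors {2, 2}.
  std::map<int64_t, int64_t> lhsIndexMap = {{0, 0}, {2, 1}}; // i->dim0, k->dim1
  std::vector<int64_t> vectorOffsets = {1, 0, 1};            // (i, j, k) instance
  std::vector<int64_t> lhsBasis = {2, 1};                    // strides of {2, 2}
  std::vector<int64_t> sliceOffsets(lhsIndexMap.size());
  for (const auto &[iterDim, vecDim] : lhsIndexMap)
    sliceOffsets[vecDim] = vectorOffsets[iterDim];           // -> {1, 1}
  int64_t linear = 0;
  for (size_t i = 0; i < sliceOffsets.size(); ++i)
    linear += sliceOffsets[i] * lhsBasis[i];                 // -> 3
  return linear == 3 ? 0 : 1;
}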
// Returns an unrolled vector at 'vectorOffsets' within the vector operand
// represented by 'state'. The value is created if not present in 'cache'.
static Value *getOrCreateUnrolledOperandSlice(
Location loc, UnrolledOperandState &state, ArrayRef<int64_t> vectorOffsets,
// Returns an unrolled vector at 'vectorOffsets' within the vector
// represented by 'state'. The vector is created from a slice of 'initValue'
// if not present in 'cache'.
static Value *getOrCreateUnrolledVectorSlice(
Location loc, UnrolledVectorState &state, ArrayRef<int64_t> vectorOffsets,
ArrayRef<int64_t> offsets, DenseMap<int64_t, int64_t> &indexMap,
SmallVectorImpl<Value *> &cache, PatternRewriter &builder) {
// Compute operand offsets.
Value *initValue, SmallVectorImpl<Value *> &cache,
PatternRewriter &builder) {
// Compute slice offsets.
SmallVector<int64_t, 4> sliceOffsets(state.unrolledShape.size());
getMappedElements(indexMap, offsets, sliceOffsets);
// TODO(b/144845578) Support non-1 strides.
SmallVector<int64_t, 4> sliceStrides(state.unrolledShape.size(), 1);
// Compute linear index of 'sliceOffsets' w.r.t 'state.basis'.
int64_t sliceLinearIndex =
getUnrolledOperandLinearIndex(state, vectorOffsets, indexMap);
getUnrolledVectorLinearIndex(state, vectorOffsets, indexMap);
assert(sliceLinearIndex < static_cast<int64_t>(cache.size()));
auto *operandSlice = cache[sliceLinearIndex];
if (operandSlice == nullptr) {
// Initialize 'cache' with slice from 'state.operand'.
operandSlice = builder.create<vector::StridedSliceOp>(
loc, state.operand, sliceOffsets, state.unrolledShape, sliceStrides);
auto *valueSlice = cache[sliceLinearIndex];
if (valueSlice == nullptr) {
assert(initValue != nullptr);
// Initialize 'cache' with a slice of 'initValue'.
valueSlice = builder.create<vector::StridedSliceOp>(
loc, initValue, sliceOffsets, state.unrolledShape, sliceStrides);
// Store value back to 'cache'.
cache[sliceLinearIndex] = operandSlice;
cache[sliceLinearIndex] = valueSlice;
}
return operandSlice;
return valueSlice;
}
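The cache here is plain lazy initialization keyed by linear index, so each unrolled slice is materialized at most once. A minimal standalone analogue (Slice is a hypothetical stand-in for the vector.strided_slice result):

#include <cstdint>
#include <memory>
#include <vector>

struct Slice { int64_t linearIndex; };

static Slice *getOrCreateSlice(std::vector<std::unique_ptr<Slice>> &cache,
                               int64_t linearIndex) {
  if (!cache[linearIndex])                  // miss: create the slice once
    cache[linearIndex] = std::make_unique<Slice>(Slice{linearIndex});
  return cache[linearIndex].get();          // hit: reuse the existing slice
}

int main() {
  std::vector<std::unique_ptr<Slice>> cache(4); // one slot per unrolled instance
  Slice *a = getOrCreateSlice(cache, 3);
  Slice *b = getOrCreateSlice(cache, 3);
  return a == b ? 0 : 1;                    // same slice both times
}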
// VectorState aggregates per-operand/result vector state required for
// creating slices of vector operands, and clones of the operation being
// unrolled.
struct VectorState {
// The type of this vector.
VectorType type;
// Map from iteration space index to vector dimension index.
DenseMap<int64_t, int64_t> indexMap;
// Index of this value in operation's operand list (-1 if not an operand).
int64_t operandIndex = -1;
// Accumulator iterator flag.
bool isAcc = false;
};
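For concreteness, a hypothetical population of this state for a matmul-like contraction (lhs 4x16, rhs 16x8, acc/result 4x8; iteration dims i=0, j=1, k=2), mirrored with std types so it stands alone:

#include <cstdint>
#include <map>
#include <vector>

struct VectorStateDemo {               // std-types mirror of VectorState above
  std::vector<int64_t> shape;          // stand-in for VectorType
  std::map<int64_t, int64_t> indexMap; // iteration dim -> vector dim
  int64_t operandIndex = -1;
  bool isAcc = false;
};

int main() {
  std::vector<VectorStateDemo> vectors = {
      {{4, 16}, {{0, 0}, {2, 1}}, /*operandIndex=*/0, /*isAcc=*/false}, // lhs
      {{16, 8}, {{2, 0}, {1, 1}}, /*operandIndex=*/1, /*isAcc=*/false}, // rhs
      {{4, 8},  {{0, 0}, {1, 1}}, /*operandIndex=*/2, /*isAcc=*/true},  // acc
  };
  return vectors.size() == 3 ? 0 : 1;
}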
//
// unrollSingleResultStructuredOp
//
// Returns a value representing the result of structured operation 'op'
// with iteration bounds 'iterationBounds' unrolled to 'targetShape'.
// An iteration space index map argument 'iterationIndexMapList' must be
// specified, with a map for each structured op input and a single map for the
// single result. The map at index 'indexMapListResultIndex' in the list must
// be the single result map.
// A list of VectorState objects must be specified in 'vectors', where
// each VectorState in the list represents a vector operand or vector result
// (if the operation does not have an accumulator operand).
// The VectorState at index 'resultIndex' in the list must be the state
// associated with the operation's single result (i.e., either its accumulator
// operand or vector result value).
//
// Example:
//
@@ -304,12 +269,23 @@ static Value *getOrCreateUnrolledOperandSlice(
// insertslice
// |
// TODO(andydavis) Add the following canonicalization/simplification patterns:
// *) Add pattern which matches InsertStridedSlice -> StridedSlice and forwards
// InsertStridedSlice operand to StridedSlice.
// *) Add pattern which matches SourceOp -> StridedSlice -> UserOp which checks
// if there are duplicate identical StridedSlice ops from SourceOp, and
// rewrites itself to use the first duplicate. This transformation should
// cause users of identical StridedSlice ops to reuse the same StridedSlice
// operation, and leave the duplicate StridedSlice ops with no users
// (removable with DCE).
// TODO(andydavis) Generalize this to support structured ops beyond
// vector ContractionOp, and merge it with 'unrollSingleResultOpMatchingType'
static Value *unrollSingleResultStructuredOp(
Operation *op, ArrayRef<int64_t> iterationBounds,
std::vector<DenseMap<int64_t, int64_t>> &iterationIndexMapList,
unsigned indexMapListResultIndex, ArrayRef<int64_t> targetShape,
static Value *unrollSingleResultStructuredOp(Operation *op,
ArrayRef<int64_t> iterationBounds,
std::vector<VectorState> &vectors,
unsigned resultIndex,
ArrayRef<int64_t> targetShape,
PatternRewriter &builder) {
auto shapedType = op->getResult(0)->getType().dyn_cast_or_null<ShapedType>();
if (!shapedType || !shapedType.hasStaticShape())
@@ -321,27 +297,25 @@ static Value *unrollSingleResultStructuredOp(
assert(false && "Failed to compute unroll factors for target shape");
auto unrollFactors = *maybeUnrollFactors;
// Compute unrolled operation state for each mapped operand.
unsigned numMaps = iterationIndexMapList.size();
SmallVector<UnrolledOperandState, 3> unrolledOperandState(numMaps);
assert(op->getNumOperands() >= numMaps);
for (unsigned i = 0; i < numMaps; ++i) {
getUnrolledOperandState(op->getOperand(i), iterationIndexMapList[i],
targetShape, unrolledOperandState[i]);
// Compute unrolled vector state for each vector in 'vectors'.
unsigned numVectors = vectors.size();
SmallVector<UnrolledVectorState, 3> unrolledVectorState(numVectors);
for (unsigned i = 0; i < numVectors; ++i) {
initUnrolledVectorState(vectors[i].type, vectors[i].indexMap, targetShape,
unrolledVectorState[i]);
}
// Compute number of total unrolled instances.
auto numUnrolledInstances = computeMaxLinearIndex(unrollFactors);
auto basis = computeStrides(unrollFactors);
auto &resultOperandState = unrolledOperandState[indexMapListResultIndex];
auto unrolledResultType = VectorType::get(resultOperandState.unrolledShape,
auto &resultValueState = unrolledVectorState[resultIndex];
auto unrolledResultType = VectorType::get(resultValueState.unrolledShape,
shapedType.getElementType());
// Initialize caches for intermediate vector results.
std::vector<SmallVector<Value *, 4>> caches(numMaps);
for (unsigned i = 0; i < numMaps; ++i) {
caches[i].resize(unrolledOperandState[i].numInstances);
}
std::vector<SmallVector<Value *, 4>> caches(numVectors);
for (unsigned i = 0; i < numVectors; ++i)
caches[i].resize(unrolledVectorState[i].numInstances);
// Unroll 'numUnrolledInstances' of 'op', storing results in 'caches'.
for (unsigned i = 0; i < numUnrolledInstances; ++i) {
@@ -352,11 +326,15 @@ static Value *unrollSingleResultStructuredOp(
vectorOffsets, targetShape);
// Get cached slice (or create slice) for each operand at 'offsets'.
SmallVector<Value *, 3> operands;
operands.reserve(numMaps);
for (unsigned i = 0; i < numMaps; ++i) {
operands.push_back(getOrCreateUnrolledOperandSlice(
op->getLoc(), unrolledOperandState[i], vectorOffsets, offsets,
iterationIndexMapList[i], caches[i], builder));
operands.resize(op->getNumOperands());
for (unsigned i = 0; i < numVectors; ++i) {
int64_t operandIndex = vectors[i].operandIndex;
if (operandIndex < 0)
continue; // Result-only vector: not an operand, nothing to slice.
auto *operand = op->getOperand(operandIndex);
operands[operandIndex] = getOrCreateUnrolledVectorSlice(
op->getLoc(), unrolledVectorState[i], vectorOffsets, offsets,
vectors[i].indexMap, operand, caches[i], builder);
}
// Create op on sliced vector arguments.
auto resultVector =
@@ -365,146 +343,117 @@
->getResult(0);
// Compute linear result index.
int64_t resultIndex = getUnrolledOperandLinearIndex(
resultOperandState, vectorOffsets,
iterationIndexMapList[indexMapListResultIndex]);
// Update result cache at 'resultIndex'.
caches[indexMapListResultIndex][resultIndex] = resultVector;
int64_t linearIndex = getUnrolledVectorLinearIndex(
resultValueState, vectorOffsets, vectors[resultIndex].indexMap);
// Update result cache at 'linearIndex'.
caches[resultIndex][linearIndex] = resultVector;
}
// Make a zero splat into which we will insert results from
// 'cache[indexMapListResultIndex]'
// 'cache[resultIndex]'
auto resultVectorType = op->getResult(0)->getType().cast<VectorType>();
auto *res = makeSplatZero(op->getLoc(), builder, resultVectorType);
SmallVector<int64_t, 4> strides(resultOperandState.unrollFactors.size(), 1);
SmallVector<int64_t, 4> strides(resultValueState.unrollFactors.size(), 1);
// Insert vector accumulators into output.
for (unsigned i = 0; i < resultOperandState.numInstances; ++i) {
auto vectorOffsets = delinearize(i, resultOperandState.basis);
for (unsigned i = 0; i < resultValueState.numInstances; ++i) {
auto vectorOffsets = delinearize(i, resultValueState.basis);
// Convert from unrolled vector-space offsets to element-space offsets.
auto offsets = zipMap([](int64_t v1, int64_t v2) { return v1 * v2; },
vectorOffsets, resultOperandState.unrolledShape);
vectorOffsets, resultValueState.unrolledShape);
res = builder.create<vector::InsertStridedSliceOp>(
op->getLoc(), caches[indexMapListResultIndex][i], res, offsets,
strides);
op->getLoc(), caches[resultIndex][i], res, offsets, strides);
}
return res;
}
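The tail loop's offset arithmetic, sketched standalone with hypothetical numbers: with unroll factors {2, 2} (basis {2, 1}) and a 2x4 unrolled tile, the four cached result tiles are inserted at element offsets {0,0}, {0,4}, {2,0}, {2,4}:

#include <cstdint>
#include <vector>

int main() {
  std::vector<int64_t> basis = {2, 1};          // strides of unroll factors {2, 2}
  std::vector<int64_t> unrolledShape = {2, 4};  // per-instance tile shape
  for (int64_t i = 0; i < 4; ++i) {
    int64_t rem = i;
    std::vector<int64_t> elemOffsets(basis.size());
    for (size_t d = 0; d < basis.size(); ++d) {
      // Delinearize, then scale vector-space offsets by the tile shape.
      elemOffsets[d] = (rem / basis[d]) * unrolledShape[d];
      rem %= basis[d];
    }
    // i=0 -> {0,0}, i=1 -> {0,4}, i=2 -> {2,0}, i=3 -> {2,4}
  }
  return 0;
}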
// Entry point for unrolling declarative pattern rewrites.
// `op` is unrolled to the `targetShape` as follows, for each of its operands:
// 1. the unrolled type `unrolledVectorType` and number of unrolled instances
// `numUnrolledInstances` are computed from the `targetShape`. For now it is
// assumed the unrolling factors divide the vector sizes.
// 2. a fakeFork cast op is inserted that takes the operand and returns
// `numUnrolledInstances` results of type `unrolledVectorType`.
// 3. the original op is cloned `numUnrolledInstances` times, once for each
// result of the fakeFork cast op.
// 4. a fakeJoin cast op takes all these results and merges them into a single
// aggregate vector result whose size matches the original non-unrolled op
// operand types.
//
// Example:
//
// opA(operand0, operand1) // numUnrolledInstances = 3
//
// operand0 operand1
// | |
// fork fork
// <----------gather all fork ops --------->
// /|\ /|\
// f00 f01 f02 f10 f11 f12
// <---------- clone op 3 times --------->
// opA0(f00, f10), opA1(f01, f11), opA2(f02, f12)
// \ | /
// <-------------------- join ------------------------->
//
// Other local patterns then kick in iteratively (including DCE) and compose
// until all the fakeFork and fakeJoin ops are removed.
//
// This will be extended in the future to support more advanced use cases than
// simple pointwise ops.
Value * mlir::vector::unrollSingleResultOpMatchingType(PatternRewriter &builder,
Operation *op,
ArrayRef<int64_t> targetShape) {
if (auto contractionOp = dyn_cast<vector::ContractionOp>(op)) {
static void getVectorContractionOpUnrollState(
vector::ContractionOp contractionOp, ArrayRef<int64_t> targetShape,
SmallVectorImpl<int64_t> &iterationBounds,
std::vector<VectorState> &vectors, unsigned &resultIndex) {
// Get contraction op iteration bounds.
SmallVector<int64_t, 6> iterationBounds;
contractionOp.getIterationBounds(iterationBounds);
assert(iterationBounds.size() == targetShape.size());
// Get map from iteration space index to lhs/rhs/result shape index.
std::vector<DenseMap<int64_t, int64_t>> iterationIndexMapList;
contractionOp.getIterationIndexMap(iterationIndexMapList);
unsigned numIterators = iterationIndexMapList.size();
vectors.resize(numIterators);
unsigned accOperandIndex = vector::ContractionOp::getAccOperandIndex();
for (unsigned i = 0; i < numIterators; ++i) {
vectors[i].type = contractionOp.getOperand(i)->getType().cast<VectorType>();
vectors[i].indexMap = iterationIndexMapList[i];
vectors[i].operandIndex = i;
vectors[i].isAcc = (i == accOperandIndex);
}
if (llvm::size(contractionOp.masks()) == 2) {
// Add maps for lhs/rhs vector mask arguments (same lhs/rhs vector shape)
iterationIndexMapList.push_back(iterationIndexMapList[0]);
iterationIndexMapList.push_back(iterationIndexMapList[1]);
// Add vectors for lhs/rhs vector mask arguments. Masks have the
// same vector shape as the lhs/rhs args, so copy their index maps.
vectors.push_back(
{vectors[0].type, vectors[0].indexMap, accOperandIndex + 1, false});
vectors.push_back(
{vectors[1].type, vectors[1].indexMap, accOperandIndex + 2, false});
}
// Unroll 'op' with 'iterationBounds' to 'targetShape'.
// TODO(andydavis) Use linalg style 'args_in'/'args_out' to partition
// 'iterationIndexMapList' instead of 'indexMapListResultIndex'.
return unrollSingleResultStructuredOp(
op, iterationBounds, iterationIndexMapList,
/*indexMapListResultIndex=*/2, targetShape, builder);
}
// TODO(andydavis) Create trivial iteration bounds and index map for
// elementwise operations and call 'unrollSingleResultStructuredOp'. Remove
// fakefork/join if possible.
LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE
"]: unrollSingleResultOpMatchingType on func:\n");
LLVM_DEBUG(op->getParentOfType<FuncOp>().print(dbgs()));
if (!op->getNumResults())
assert(false && "Use precondition till RewriterGen can act on nullptr");
auto shapedType = op->getResult(0)->getType().dyn_cast_or_null<ShapedType>();
if (!shapedType || !shapedType.hasStaticShape())
assert(false && "Use precondition till RewriterGen can act on nullptr");
auto shape = shapedType.getShape();
auto maybeUnrollFactors = shapeRatio(shape, targetShape);
if (!maybeUnrollFactors.hasValue())
assert(false && "Use precondition till RewriterGen can act on nullptr");
auto unrollFactors = *maybeUnrollFactors;
auto loc = op->getLoc();
auto numUnrolledInstances = computeMaxLinearIndex(unrollFactors);
auto unrolledVectorType =
VectorType::get(targetShape, shapedType.getElementType());
SmallVector<Type, 4> forkedType(numUnrolledInstances, unrolledVectorType);
SmallVector<Operation *, 4> forkeds;
forkeds.reserve(numUnrolledInstances);
// Create a new forkOp for each operand.
for (auto *operand : op->getOperands())
forkeds.push_back(
createFakeForkOp(builder, loc, operand, forkedType, unrollFactors));
SmallVector<Operation *, 4> newOps;
newOps.reserve(numUnrolledInstances);
for (int64_t idx = 0; idx < numUnrolledInstances; ++idx) {
SmallVector<Value *, 4> operands;
operands.reserve(forkeds.size());
for (auto *fork : forkeds) {
operands.push_back(fork->getResult(idx));
}
newOps.push_back(cloneOpWithOperandsAndTypes(builder, loc, op, operands,
unrolledVectorType));
}
SmallVector<Value *, 4> newOpResults;
newOpResults.reserve(newOps.size());
for (auto *newOp : newOps)
newOpResults.push_back(newOp->getResult(0));
return createFakeJoinOp(builder, loc, newOpResults, shapedType, unrollFactors,
{0})
->getResult(0);
// 'vectors' instead of 'resultIndex'.
resultIndex = accOperandIndex;
}
// Patterns with this benefit just forward arguments to clean up fake forks and
// fake joins. This is a nicer and more direct cleanup when applicable, so it
// kicks in with higher precedence.
static constexpr int64_t kMatchingFakeForkFakeJoinBenefit = 1;
static void
getVectorElementwiseOpUnrollState(Operation *op, ArrayRef<int64_t> targetShape,
SmallVectorImpl<int64_t> &iterationBounds,
std::vector<VectorState> &vectors,
unsigned &resultIndex) {
// Verify that the operation and its operands all have the same vector shape.
auto resultType = op->getResult(0)->getType().dyn_cast_or_null<VectorType>();
assert(resultType && "Expected op with vector result type");
auto resultShape = resultType.getShape();
// Verify that all operands have the same vector type as result.
assert(llvm::all_of(op->getOperandTypes(),
[=](Type type) { return type == resultType; }));
// Populate 'iterationBounds' with 'resultShape' for elementwise operations.
iterationBounds.assign(resultShape.begin(), resultShape.end());
// Create trivial elementwise identity index map based on 'resultShape'.
DenseMap<int64_t, int64_t> indexMap;
indexMap.reserve(resultShape.size());
for (unsigned i = 0; i < resultShape.size(); ++i)
indexMap[i] = i;
// Create a VectorState for each operand and for the single result.
unsigned numVectors = op->getNumOperands() + op->getNumResults();
vectors.resize(numVectors);
for (unsigned i = 0; i < op->getNumOperands(); ++i)
vectors[i] = {resultType, indexMap, i, false};
vectors[numVectors - 1] = {resultType, indexMap, -1, false};
resultIndex = numVectors - 1;
}
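For example (hypothetical shapes): a binary elementwise op on vector<4x8xf32> with targetShape {2, 4} gets iteration bounds {4, 8}, an identity index map, and a trailing result-only entry with operandIndex == -1. Mirrored with std types:

#include <cstdint>
#include <map>
#include <vector>

int main() {
  std::vector<int64_t> resultShape = {4, 8};
  std::vector<int64_t> iterationBounds(resultShape.begin(), resultShape.end());
  std::map<int64_t, int64_t> indexMap;
  for (int64_t i = 0; i < static_cast<int64_t>(resultShape.size()); ++i)
    indexMap[i] = i;                 // identity: iteration dim i -> vector dim i
  // Binary op: operands at indices 0 and 1 share the map; the result entry
  // comes last with operandIndex == -1.
  unsigned numVectors = 2 + 1;
  unsigned resultIndex = numVectors - 1; // == 2
  return resultIndex == 2 ? 0 : 1;
}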
// Entry point for unrolling declarative pattern rewrites.
Value *mlir::vector::unrollSingleResultOpMatchingType(
PatternRewriter &builder, Operation *op, ArrayRef<int64_t> targetShape) {
assert(op->getNumResults() == 1 && "Expected single result operation");
// Populate 'iterationBounds', 'vectors' and 'resultIndex' to unroll 'op'.
SmallVector<int64_t, 6> iterationBounds;
std::vector<VectorState> vectors;
unsigned resultIndex;
if (auto contractionOp = dyn_cast<vector::ContractionOp>(op)) {
// Populate state for vector ContractionOp.
getVectorContractionOpUnrollState(contractionOp, targetShape,
iterationBounds, vectors, resultIndex);
} else {
// Populate state for vector elementwise op.
getVectorElementwiseOpUnrollState(op, targetShape, iterationBounds, vectors,
resultIndex);
}
// Unroll 'op' with 'iterationBounds' to 'targetShape'.
return unrollSingleResultStructuredOp(op, iterationBounds, vectors,
resultIndex, targetShape, builder);
}
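A hedged usage sketch: a hypothetical rewrite pattern driving this entry point to unroll a standard addf on vectors. Only the unrollSingleResultOpMatchingType signature comes from this commit; the pattern itself is illustrative and assumes the surrounding file's includes and this era's Value*/PatternMatchResult APIs:

struct UnrollAddFPattern : public RewritePattern {
  UnrollAddFPattern(MLIRContext *context)
      : RewritePattern(AddFOp::getOperationName(), /*benefit=*/1, context) {}
  PatternMatchResult matchAndRewrite(Operation *op,
                                     PatternRewriter &rewriter) const override {
    // E.g. unroll a vector<4x8xf32> addf into four vector<2x4xf32> pieces.
    Value *result = vector::unrollSingleResultOpMatchingType(
        rewriter, op, /*targetShape=*/{2, 4});
    rewriter.replaceOp(op, result);
    return matchSuccess();
  }
};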
namespace mlir {
namespace vector {
@@ -514,177 +463,9 @@ namespace {
} // end namespace vector
} // end namespace mlir
// Match a fakeFork fed by a fakeJoin and just forward its operands.
// This is akin to calling `replaceAllUsesOf` but made to play nice with all the
// other RewritePatterns.
struct ConvertMatchingFakeForkFakeJoinOp : public RewritePattern {
ConvertMatchingFakeForkFakeJoinOp(MLIRContext *context)
// low-benefit to kick-in late
: RewritePattern(kFakeForkOp, kMatchingFakeForkFakeJoinBenefit, context) {
}
PatternMatchResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override {
if (op->getNumOperands() != 1)
return matchFailure();
auto *definingOp = op->getOperand(0)->getDefiningOp();
if (!definingOp || definingOp->getName().getStringRef() != kFakeJoinOp)
return matchFailure();
if (definingOp->getNumOperands() != op->getNumResults())
return matchFailure();
for (auto it : llvm::zip(definingOp->getOperands(), op->getResults())) {
if (std::get<0>(it)->getType() != std::get<1>(it)->getType())
return matchFailure();
}
LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE
"]: ConvertMatchingFakeForkFakeJoinOp on op: "
<< *op << " in func:\n");
LLVM_DEBUG(op->getParentOfType<FuncOp>().print(dbgs()));
rewriter.replaceOp(op, definingOp->getOperands());
return matchSuccess();
}
};
// Rewrites a fakeFork, whose (unique) operand is a blockArgument or the result
// of a vector.transfer_read, into multiple vector.strided_slice ops.
struct ConvertFakeForkFromBlockArgsOrTransferReadOp : public RewritePattern {
ConvertFakeForkFromBlockArgsOrTransferReadOp(MLIRContext *context)
// low-benefit to kick-in late
: RewritePattern(kFakeForkOp, 0, context) {}
PatternMatchResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override {
if (op->getNumOperands() != 1)
return matchFailure();
if (op->use_empty()) {
rewriter.eraseOp(op);
return matchSuccess();
}
auto *arg = op->getOperand(0);
if (!isa<BlockArgument>(arg) &&
!isa<vector::TransferReadOp>(arg->getDefiningOp()))
return matchFailure();
LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE
"]: ConvertFakeForkFromBlockArgsOp on op: "
<< *op << " in func:\n");
LLVM_DEBUG(op->getParentOfType<FuncOp>().print(dbgs()));
// Look at the unroll factors remaining on this op and act on the first one.
auto unrollFactorsStorage = extractUnrollFactors(op);
ArrayRef<int64_t> unrollFactors{unrollFactorsStorage};
if (unrollFactors.empty()) {
// No more unrollFactors, just sanity check + forward the unique operand.
assert(op->getNumResults() == 1);
assert(arg->getType() == op->getResult(0)->getType());
rewriter.replaceOp(op, arg);
return matchSuccess();
}
// Strides are always 1 for now.
// TODO(b/144845578) support non-1 strides.
auto forkedVectorType = arg->getType().cast<VectorType>();
SmallVector<int64_t, 4> strides(unrollFactors.size(), 1);
auto nUnrolled = computeMaxLinearIndex(unrollFactors);
SmallVector<Value *, 4> extractedVectors;
extractedVectors.reserve(op->getNumResults());
auto linearizationBasis = computeStrides(unrollFactors);
for (unsigned idx = 0; idx < nUnrolled; ++idx) {
auto offsets = delinearize(idx, linearizationBasis);
offsets = zipMap([](int64_t v1, int64_t v2) { return v1 * v2; }, offsets,
unrollFactors);
auto leadingSize =
forkedVectorType.getShape().take_front(unrollFactors.size());
auto sizes = zipMap([](int64_t v1, int64_t v2) { return v1 / v2; },
leadingSize, unrollFactors);
extractedVectors.push_back(
rewriter
.create<vector::StridedSliceOp>(op->getLoc(), arg, offsets, sizes,
strides)
.getResult());
}
rewriter.replaceOp(op, extractedVectors);
return matchSuccess();
}
};
// Rewrites a fakeJoin into a zero splat into which its operands are combined
// via multiple vector.insert_strided_slice ops.
struct ConvertFakeJoinOp : public RewritePattern {
ConvertFakeJoinOp(MLIRContext *context)
// low-benefit to kick-in late
: RewritePattern(kFakeJoinOp, 0, context) {}
PatternMatchResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override {
if (op->getNumResults() != 1)
return matchFailure();
if (op->use_empty()) {
rewriter.eraseOp(op);
return matchSuccess();
}
auto resultVectorType = op->getResult(0)->getType().cast<VectorType>();
auto loc = op->getLoc();
auto *res = makeSplatZero(loc, rewriter, resultVectorType);
auto unrollFactorsStorage = extractUnrollFactors(op);
ArrayRef<int64_t> unrollFactors{unrollFactorsStorage};
auto linearizationBasis = computeStrides(unrollFactors);
auto nUnrolled = computeMaxLinearIndex(unrollFactors);
SmallVector<int64_t, 4> strides(unrollFactors.size(), 1);
for (unsigned idx = 0; idx < nUnrolled; ++idx) {
auto offsets = delinearize(idx, linearizationBasis);
offsets = zipMap([](int64_t v1, int64_t v2) { return v1 * v2; }, offsets,
unrollFactors);
res = rewriter.create<vector::InsertStridedSliceOp>(
loc, op->getOperand(idx), res, offsets, strides);
}
rewriter.replaceOp(op, res);
return matchSuccess();
}
};
// Simple DCE for fakeForkOps/fakeJoinOps; we do not want them to escape a
// transformation (otherwise the transformation is considered incorrect).
struct FakeForkTrait {
static constexpr char const *name = kFakeForkOp;
};
struct FakeJoinTrait {
static constexpr char const *name = kFakeJoinOp;
};
template <typename OpNameTrait> struct DCEPattern : public RewritePattern {
DCEPattern(MLIRContext *context)
// low-benefit to kick-in late
: RewritePattern(OpNameTrait::name, 0, context) {}
PatternMatchResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override {
assert(op->getName().getStringRef() == kFakeForkOp ||
op->getName().getStringRef() == kFakeJoinOp);
if (!op->use_empty())
return matchFailure();
rewriter.eraseOp(op);
return matchSuccess();
}
};
void mlir::populateVectorToVectorConversionPatterns(
MLIRContext *context, OwningRewritePatternList &patterns,
ArrayRef<int64_t> coarseVectorShape, ArrayRef<int64_t> fineVectorShape) {
vector::populateWithGenerated(context, &patterns);
vector::populateVectorToVectorCanonicalizationPatterns(patterns, context);
patterns
.insert<ConvertMatchingFakeForkFakeJoinOp,
ConvertFakeForkFromBlockArgsOrTransferReadOp, ConvertFakeJoinOp,
DCEPattern<FakeForkTrait>, DCEPattern<FakeJoinTrait>>(context);
}