Refactor / improve replaceAllMemRefUsesWith

Refactor replaceAllMemRefUsesWith to split it into two methods: the new method
performs the replacement on a single op and is used by the existing one.

- Make the methods return LogicalResult instead of bool.

- Earlier, when replacement failed (due to non-dereferencing uses of the
  memref), the ops that had already been processed would have been replaced,
  leaving the IR in an inconsistent state. Now, a pass is first made over all
  ops to check for non-dereferencing uses, and only then is the replacement
  performed. No test cases were affected because all existing clients of this
  method were already checking for non-dereferencing uses before calling it
  (for other reasons). This isn't true for a use case in an upcoming PR
  (scalar replacement); clients can now bail out with consistent IR when
  replaceAllMemRefUsesWith fails. Add a test case.

- Multiple dereferencing uses of the same memref in a single op are possible
  (we have no such use cases/scenarios yet), and this has always remained
  unsupported. Add an assertion for this.

- Minor fix to another pipeline-data-transfer test case.

Signed-off-by: Uday Bondhugula <uday@polymagelabs.com>

Closes #87

COPYBARA_INTEGRATE_REVIEW=https://github.com/tensorflow/mlir/pull/87 from bondhugula:memref 5153a6194d875eb0424ea8a4ebf1764da8035645
PiperOrigin-RevId: 265808183
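The new return type makes the bail-out pattern explicit. A minimal sketch of
the client-side idiom this enables (mirroring the updated doubleBuffer call
site further down in this diff; `oldMemRef` and `newMemRef` stand for the
caller's values):

    // Since no op is rewritten until the whole use list has been vetted, a
    // failed() result guarantees the IR was left untouched and consistent.
    if (failed(replaceAllMemRefUsesWith(oldMemRef, newMemRef))) {
      // Safe to abandon the transformation here.
      return;
    }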
parent bb9ce9acec
commit 63ba081d07
third_party/mlir/include/mlir/Transforms/Utils.h (38 changes)
@@ -37,26 +37,26 @@ class AffineForOp;
 class Location;
 class OpBuilder;
 
-/// Replaces all "deferencing" uses of oldMemRef with newMemRef while optionally
-/// remapping the old memref's indices using the supplied affine map,
-/// 'indexRemap'. The new memref could be of a different shape or rank.
-/// 'extraIndices' provides additional access indices to be added to the start.
+/// Replaces all "dereferencing" uses of `oldMemRef` with `newMemRef` while
+/// optionally remapping the old memref's indices using the supplied affine map,
+/// `indexRemap`. The new memref could be of a different shape or rank.
+/// `extraIndices` provides additional access indices to be added to the start.
 ///
-/// 'indexRemap' remaps indices of the old memref access to a new set of indices
+/// `indexRemap` remaps indices of the old memref access to a new set of indices
 /// that are used to index the memref. Additional input operands to indexRemap
 /// can be optionally provided, and they are added at the start of its input
-/// list. 'indexRemap' is expected to have only dimensional inputs, and the
+/// list. `indexRemap` is expected to have only dimensional inputs, and the
 /// number of its inputs equal to extraOperands.size() plus rank of the memref.
 /// 'extraOperands' is an optional argument that corresponds to additional
 /// operands (inputs) for indexRemap at the beginning of its input list.
 ///
-/// 'domInstFilter', if non-null, restricts the replacement to only those
+/// `domInstFilter`, if non-null, restricts the replacement to only those
 /// operations that are dominated by the former; similarly, `postDomInstFilter`
 /// restricts replacement to only those operations that are postdominated by it.
 ///
 /// Returns true on success and false if the replacement is not possible,
-/// whenever a memref is used as an operand in a non-deferencing context, except
-/// for dealloc's on the memref which are left untouched. See comments at
+/// whenever a memref is used as an operand in a non-dereferencing context,
+/// except for dealloc's on the memref which are left untouched. See comments at
 /// function definition for an example.
 //
 // Ex: to replace load %A[%i, %j] with load %Abuf[%t mod 2, %ii - %i, %j]:
@@ -66,12 +66,20 @@ class OpBuilder;
 // extra operands, note that 'indexRemap' would just be applied to existing
 // indices (%i, %j).
 // TODO(bondhugula): allow extraIndices to be added at any position.
-bool replaceAllMemRefUsesWith(Value *oldMemRef, Value *newMemRef,
-                              ArrayRef<Value *> extraIndices = {},
-                              AffineMap indexRemap = AffineMap(),
-                              ArrayRef<Value *> extraOperands = {},
-                              Operation *domInstFilter = nullptr,
-                              Operation *postDomInstFilter = nullptr);
+LogicalResult replaceAllMemRefUsesWith(Value *oldMemRef, Value *newMemRef,
+                                       ArrayRef<Value *> extraIndices = {},
+                                       AffineMap indexRemap = AffineMap(),
+                                       ArrayRef<Value *> extraOperands = {},
+                                       Operation *domInstFilter = nullptr,
+                                       Operation *postDomInstFilter = nullptr);
+
+/// Performs the same replacement as the other version above but only for the
+/// dereferencing uses of `oldMemRef` in `op`.
+LogicalResult replaceAllMemRefUsesWith(Value *oldMemRef, Value *newMemRef,
+                                       Operation *op,
+                                       ArrayRef<Value *> extraIndices = {},
+                                       AffineMap indexRemap = AffineMap(),
+                                       ArrayRef<Value *> extraOperands = {});
 
 /// Creates and inserts into 'builder' a new AffineApplyOp, with the number of
 /// its results equal to the number of operands, as a composition
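To make the Ex in the doc comment concrete, here is a hedged sketch of how a
caller might assemble that remapping with this era's Builder API (`b` is an
assumed OpBuilder; `tModTwo` and `ii` are assumed Value*s holding `%t mod 2`
and `%ii`):

    // New access [%t mod 2, %ii - %i, %j]: extraIndices contributes the
    // leading index, and indexRemap computes (%ii - %i, %j) from the inputs
    // (extraOperands..., old indices...) = (%ii, %i, %j).
    auto d0 = b.getAffineDimExpr(0); // %ii (extra operand, prepended)
    auto d1 = b.getAffineDimExpr(1); // %i  (old index 0)
    auto d2 = b.getAffineDimExpr(2); // %j  (old index 1)
    AffineMap indexRemap =
        b.getAffineMap(/*dimCount=*/3, /*symbolCount=*/0, {d0 - d1, d2});
    if (failed(replaceAllMemRefUsesWith(oldMemRef, newMemRef,
                                        /*extraIndices=*/{tModTwo}, indexRemap,
                                        /*extraOperands=*/{ii})))
      return;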
@@ -952,12 +952,13 @@ static Value *createPrivateMemRef(AffineForOp forOp, Operation *srcStoreOpInst,
                         ? AffineMap()
                         : b.getAffineMap(outerIVs.size() + rank, 0, remapExprs);
   // Replace all users of 'oldMemRef' with 'newMemRef'.
-  bool ret =
+  LogicalResult res =
       replaceAllMemRefUsesWith(oldMemRef, newMemRef, {}, indexRemap,
                                /*extraOperands=*/outerIVs,
                                /*domInstFilter=*/&*forOp.getBody()->begin());
-  assert(ret && "replaceAllMemrefUsesWith should always succeed here");
-  (void)ret;
+  assert(succeeded(res) &&
+         "replaceAllMemrefUsesWith should always succeed here");
+  (void)res;
   return newMemRef;
 }
 
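A brief note on the idiom, since this is the first converted call site:
LogicalResult (from MLIR's support library) is a two-state success/failure
value queried through free functions, which avoids the inverted-bool pitfalls
of the old API:

    LogicalResult res = success();      // or failure()
    if (succeeded(res)) { /* replacement happened */ }
    if (failed(res))    { /* bail out; IR untouched */ }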
@@ -115,13 +115,14 @@ static bool doubleBuffer(Value *oldMemRef, AffineForOp forOp) {
   auto ivModTwoOp = bInner.create<AffineApplyOp>(forOp.getLoc(), modTwoMap,
                                                  forOp.getInductionVar());
 
-  // replaceAllMemRefUsesWith will always succeed unless the forOp body has
-  // non-deferencing uses of the memref (dealloc's are fine though).
-  if (!replaceAllMemRefUsesWith(oldMemRef, newMemRef,
-                                /*extraIndices=*/{ivModTwoOp},
-                                /*indexRemap=*/AffineMap(),
-                                /*extraOperands=*/{},
-                                /*domInstFilter=*/&*forOp.getBody()->begin())) {
+  // replaceAllMemRefUsesWith will succeed unless the forOp body has
+  // non-dereferencing uses of the memref (dealloc's are fine though).
+  if (failed(replaceAllMemRefUsesWith(
+          oldMemRef, newMemRef,
+          /*extraIndices=*/{ivModTwoOp},
+          /*indexRemap=*/AffineMap(),
+          /*extraOperands=*/{},
+          /*domInstFilter=*/&*forOp.getBody()->begin()))) {
     LLVM_DEBUG(
         forOp.emitError("memref replacement for double buffering failed"));
     ivModTwoOp.erase();
@@ -276,9 +277,9 @@ void PipelineDataTransfer::runOnAffineForOp(AffineForOp forOp) {
     if (!doubleBuffer(oldMemRef, forOp)) {
       // Normally, double buffering should not fail because we already checked
       // that there are no uses outside.
-      LLVM_DEBUG(llvm::dbgs() << "double buffering failed for: \n";);
-      LLVM_DEBUG(dmaStartInst->dump());
-      // IR still in a valid state.
+      LLVM_DEBUG(llvm::dbgs()
+                 << "double buffering failed for" << dmaStartInst << "\n";);
+      // IR still valid and semantically correct.
       return;
     }
     // If the old memref has no more uses, remove its 'dead' alloc if it was
third_party/mlir/lib/Transforms/Utils/Utils.cpp (357 changes)
@@ -57,16 +57,181 @@ static NamedAttribute getAffineMapAttrForMemRef(Operation *op, Value *memref) {
   return cast<AffineDmaWaitOp>(op).getAffineMapAttrForMemRef(memref);
 }
 
-bool mlir::replaceAllMemRefUsesWith(Value *oldMemRef, Value *newMemRef,
-                                    ArrayRef<Value *> extraIndices,
-                                    AffineMap indexRemap,
-                                    ArrayRef<Value *> extraOperands,
-                                    Operation *domInstFilter,
-                                    Operation *postDomInstFilter) {
+// Perform the replacement in `op`.
+LogicalResult mlir::replaceAllMemRefUsesWith(Value *oldMemRef, Value *newMemRef,
+                                             Operation *op,
+                                             ArrayRef<Value *> extraIndices,
+                                             AffineMap indexRemap,
+                                             ArrayRef<Value *> extraOperands) {
   unsigned newMemRefRank = newMemRef->getType().cast<MemRefType>().getRank();
-  (void)newMemRefRank; // unused in opt mode
   unsigned oldMemRefRank = oldMemRef->getType().cast<MemRefType>().getRank();
+  (void)newMemRefRank;
+  (void)oldMemRefRank;
   if (indexRemap) {
     assert(indexRemap.getNumSymbols() == 0 && "pure dimensional map expected");
     assert(indexRemap.getNumInputs() == extraOperands.size() + oldMemRefRank);
     assert(indexRemap.getNumResults() + extraIndices.size() == newMemRefRank);
   } else {
     assert(oldMemRefRank + extraIndices.size() == newMemRefRank);
   }
 
   // Assert same elemental type.
   assert(oldMemRef->getType().cast<MemRefType>().getElementType() ==
          newMemRef->getType().cast<MemRefType>().getElementType());
 
+  if (!isMemRefDereferencingOp(*op))
+    // Failure: memref used in a non-dereferencing context (potentially
+    // escapes); no replacement in these cases.
+    return failure();
+
+  SmallVector<unsigned, 2> usePositions;
+  for (const auto &opEntry : llvm::enumerate(op->getOperands())) {
+    if (opEntry.value() == oldMemRef)
+      usePositions.push_back(opEntry.index());
+  }
+
+  // If memref doesn't appear, nothing to do.
+  if (usePositions.empty())
+    return success();
+
+  if (usePositions.size() > 1) {
+    // TODO(mlir-team): extend it for this case when needed (rare).
+    assert(false && "multiple dereferencing uses in a single op not supported");
+    return failure();
+  }
+
+  unsigned memRefOperandPos = usePositions.front();
+
+  OpBuilder builder(op);
+  NamedAttribute oldMapAttrPair = getAffineMapAttrForMemRef(op, oldMemRef);
+  AffineMap oldMap = oldMapAttrPair.second.cast<AffineMapAttr>().getValue();
+  unsigned oldMapNumInputs = oldMap.getNumInputs();
+  SmallVector<Value *, 4> oldMapOperands(
+      op->operand_begin() + memRefOperandPos + 1,
+      op->operand_begin() + memRefOperandPos + 1 + oldMapNumInputs);
+
+  // Apply 'oldMemRefOperands = oldMap(oldMapOperands)'.
+  SmallVector<Value *, 4> oldMemRefOperands;
+  SmallVector<Value *, 4> affineApplyOps;
+  oldMemRefOperands.reserve(oldMemRefRank);
+  if (oldMap != builder.getMultiDimIdentityMap(oldMap.getNumDims())) {
+    for (auto resultExpr : oldMap.getResults()) {
+      auto singleResMap = builder.getAffineMap(
+          oldMap.getNumDims(), oldMap.getNumSymbols(), resultExpr);
+      auto afOp = builder.create<AffineApplyOp>(op->getLoc(), singleResMap,
+                                                oldMapOperands);
+      oldMemRefOperands.push_back(afOp);
+      affineApplyOps.push_back(afOp);
+    }
+  } else {
+    oldMemRefOperands.append(oldMapOperands.begin(), oldMapOperands.end());
+  }
+
+  // Construct new indices as a remap of the old ones if a remapping has been
+  // provided. The indices of a memref come right after it, i.e.,
+  // at position memRefOperandPos + 1.
+  SmallVector<Value *, 4> remapOperands;
+  remapOperands.reserve(extraOperands.size() + oldMemRefRank);
+  remapOperands.append(extraOperands.begin(), extraOperands.end());
+  remapOperands.append(oldMemRefOperands.begin(), oldMemRefOperands.end());
+
+  SmallVector<Value *, 4> remapOutputs;
+  remapOutputs.reserve(oldMemRefRank);
+
+  if (indexRemap &&
+      indexRemap != builder.getMultiDimIdentityMap(indexRemap.getNumDims())) {
+    // Remapped indices.
+    for (auto resultExpr : indexRemap.getResults()) {
+      auto singleResMap = builder.getAffineMap(
+          indexRemap.getNumDims(), indexRemap.getNumSymbols(), resultExpr);
+      auto afOp = builder.create<AffineApplyOp>(op->getLoc(), singleResMap,
+                                                remapOperands);
+      remapOutputs.push_back(afOp);
+      affineApplyOps.push_back(afOp);
+    }
+  } else {
+    // No remapping specified.
+    remapOutputs.append(remapOperands.begin(), remapOperands.end());
+  }
+
+  SmallVector<Value *, 4> newMapOperands;
+  newMapOperands.reserve(newMemRefRank);
+
+  // Prepend 'extraIndices' in 'newMapOperands'.
+  for (auto *extraIndex : extraIndices) {
+    assert(extraIndex->getDefiningOp()->getNumResults() == 1 &&
+           "single result op's expected to generate these indices");
+    assert((isValidDim(extraIndex) || isValidSymbol(extraIndex)) &&
+           "invalid memory op index");
+    newMapOperands.push_back(extraIndex);
+  }
+
+  // Append 'remapOutputs' to 'newMapOperands'.
+  newMapOperands.append(remapOutputs.begin(), remapOutputs.end());
+
+  // Create new fully composed AffineMap for new op to be created.
+  assert(newMapOperands.size() == newMemRefRank);
+  auto newMap = builder.getMultiDimIdentityMap(newMemRefRank);
+  // TODO(b/136262594) Avoid creating/deleting temporary AffineApplyOps here.
+  fullyComposeAffineMapAndOperands(&newMap, &newMapOperands);
+  newMap = simplifyAffineMap(newMap);
+  canonicalizeMapAndOperands(&newMap, &newMapOperands);
+  // Remove any affine.apply's that became dead as a result of composition.
+  for (auto *value : affineApplyOps)
+    if (value->use_empty())
+      value->getDefiningOp()->erase();
+
+  // Construct the new operation using this memref.
+  OperationState state(op->getLoc(), op->getName());
+  state.setOperandListToResizable(op->hasResizableOperandsList());
+  state.operands.reserve(op->getNumOperands() + extraIndices.size());
+  // Insert the non-memref operands.
+  state.operands.append(op->operand_begin(),
+                        op->operand_begin() + memRefOperandPos);
+  // Insert the new memref value.
+  state.operands.push_back(newMemRef);
+
+  // Insert the new memref map operands.
+  state.operands.append(newMapOperands.begin(), newMapOperands.end());
+
+  // Insert the remaining operands unmodified.
+  state.operands.append(op->operand_begin() + memRefOperandPos + 1 +
+                            oldMapNumInputs,
+                        op->operand_end());
+
+  // Result types don't change. Both memref's are of the same elemental type.
+  state.types.reserve(op->getNumResults());
+  for (auto *result : op->getResults())
+    state.types.push_back(result->getType());
+
+  // Add attribute for 'newMap', other Attributes do not change.
+  auto newMapAttr = builder.getAffineMapAttr(newMap);
+  for (auto namedAttr : op->getAttrs()) {
+    if (namedAttr.first == oldMapAttrPair.first) {
+      state.attributes.push_back({namedAttr.first, newMapAttr});
+    } else {
+      state.attributes.push_back(namedAttr);
+    }
+  }
+
+  // Create the new operation.
+  auto *repOp = builder.createOperation(state);
+  op->replaceAllUsesWith(repOp);
+  op->erase();
+
+  return success();
+}
+
+LogicalResult mlir::replaceAllMemRefUsesWith(Value *oldMemRef, Value *newMemRef,
+                                             ArrayRef<Value *> extraIndices,
+                                             AffineMap indexRemap,
+                                             ArrayRef<Value *> extraOperands,
+                                             Operation *domInstFilter,
+                                             Operation *postDomInstFilter) {
+  unsigned newMemRefRank = newMemRef->getType().cast<MemRefType>().getRank();
+  (void)newMemRefRank; // unused in opt mode
+  unsigned oldMemRefRank = oldMemRef->getType().cast<MemRefType>().getRank();
+  (void)oldMemRefRank;
+  if (indexRemap) {
+    assert(indexRemap.getNumSymbols() == 0 && "pure dimensional map expected");
+    assert(indexRemap.getNumInputs() == extraOperands.size() + oldMemRefRank);
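The new single-op entry point also composes into per-op drivers. A hypothetical
sketch (not part of this commit) that tolerates non-dereferencing uses instead
of failing the whole replacement, using only the declarations above:

    // Snapshot the users first: a successful per-op replacement erases the
    // visited op, which would invalidate live use-list iteration.
    SmallVector<Operation *, 8> users(oldMemRef->getUsers().begin(),
                                      oldMemRef->getUsers().end());
    for (Operation *user : users)
      if (failed(replaceAllMemRefUsesWith(oldMemRef, newMemRef, user)))
        continue; // non-dereferencing use (e.g. escaping); leave it as is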
@@ -89,170 +254,44 @@ bool mlir::replaceAllMemRefUsesWith(Value *oldMemRef, Value *newMemRef,
     postDomInfo = std::make_unique<PostDominanceInfo>(
         postDomInstFilter->getParentOfType<FuncOp>());
 
-  // The ops where memref replacement succeeds are replaced with new ones.
-  SmallVector<Operation *, 8> opsToErase;
-
-  // Walk all uses of old memref. Operation using the memref gets replaced.
-  for (auto *opInst : llvm::make_early_inc_range(oldMemRef->getUsers())) {
+  // Walk all uses of old memref; collect ops to perform replacement. We use a
+  // DenseSet since an operation could potentially have multiple uses of a
+  // memref (although rare), and the replacement later is going to erase ops.
+  DenseSet<Operation *> opsToReplace;
+  for (auto *op : oldMemRef->getUsers()) {
     // Skip this use if it's not dominated by domInstFilter.
-    if (domInstFilter && !domInfo->dominates(domInstFilter, opInst))
+    if (domInstFilter && !domInfo->dominates(domInstFilter, op))
       continue;
 
     // Skip this use if it's not post-dominated by postDomInstFilter.
-    if (postDomInstFilter &&
-        !postDomInfo->postDominates(postDomInstFilter, opInst))
+    if (postDomInstFilter && !postDomInfo->postDominates(postDomInstFilter, op))
       continue;
 
-    // Skip dealloc's - no replacement is necessary, and a replacement doesn't
-    // hurt dealloc's.
-    if (isa<DeallocOp>(opInst))
+    // Skip dealloc's - no replacement is necessary, and a memref replacement
+    // at other uses doesn't hurt these dealloc's.
+    if (isa<DeallocOp>(op))
       continue;
 
-    // Check if the memref was used in a non-deferencing context. It is fine for
-    // the memref to be used in a non-deferencing way outside of the region
-    // where this replacement is happening.
-    if (!isMemRefDereferencingOp(*opInst))
-      // Failure: memref used in a non-deferencing op (potentially escapes); no
-      // replacement in these cases.
-      return false;
+    // Check if the memref was used in a non-dereferencing context. It is fine
+    // for the memref to be used in a non-dereferencing way outside of the
+    // region where this replacement is happening.
+    if (!isMemRefDereferencingOp(*op))
+      // Failure: memref used in a non-dereferencing op (potentially escapes);
+      // no replacement in these cases.
+      return failure();
 
-    auto getMemRefOperandPos = [&]() -> unsigned {
-      unsigned i, e;
-      for (i = 0, e = opInst->getNumOperands(); i < e; i++) {
-        if (opInst->getOperand(i) == oldMemRef)
-          break;
-      }
-      assert(i < opInst->getNumOperands() && "operand guaranteed to be found");
-      return i;
-    };
-
-    OpBuilder builder(opInst);
-    unsigned memRefOperandPos = getMemRefOperandPos();
-    NamedAttribute oldMapAttrPair =
-        getAffineMapAttrForMemRef(opInst, oldMemRef);
-    AffineMap oldMap = oldMapAttrPair.second.cast<AffineMapAttr>().getValue();
-    unsigned oldMapNumInputs = oldMap.getNumInputs();
-    SmallVector<Value *, 4> oldMapOperands(
-        opInst->operand_begin() + memRefOperandPos + 1,
-        opInst->operand_begin() + memRefOperandPos + 1 + oldMapNumInputs);
-    SmallVector<Value *, 4> affineApplyOps;
-
-    // Apply 'oldMemRefOperands = oldMap(oldMapOperands)'.
-    SmallVector<Value *, 4> oldMemRefOperands;
-    oldMemRefOperands.reserve(oldMemRefRank);
-    if (oldMap != builder.getMultiDimIdentityMap(oldMap.getNumDims())) {
-      for (auto resultExpr : oldMap.getResults()) {
-        auto singleResMap = builder.getAffineMap(
-            oldMap.getNumDims(), oldMap.getNumSymbols(), resultExpr);
-        auto afOp = builder.create<AffineApplyOp>(opInst->getLoc(),
-                                                  singleResMap, oldMapOperands);
-        oldMemRefOperands.push_back(afOp);
-        affineApplyOps.push_back(afOp);
-      }
-    } else {
-      oldMemRefOperands.append(oldMapOperands.begin(), oldMapOperands.end());
-    }
-
-    // Construct new indices as a remap of the old ones if a remapping has been
-    // provided. The indices of a memref come right after it, i.e.,
-    // at position memRefOperandPos + 1.
-    SmallVector<Value *, 4> remapOperands;
-    remapOperands.reserve(extraOperands.size() + oldMemRefRank);
-    remapOperands.append(extraOperands.begin(), extraOperands.end());
-    remapOperands.append(oldMemRefOperands.begin(), oldMemRefOperands.end());
-
-    SmallVector<Value *, 4> remapOutputs;
-    remapOutputs.reserve(oldMemRefRank);
-
-    if (indexRemap &&
-        indexRemap != builder.getMultiDimIdentityMap(indexRemap.getNumDims())) {
-      // Remapped indices.
-      for (auto resultExpr : indexRemap.getResults()) {
-        auto singleResMap = builder.getAffineMap(
-            indexRemap.getNumDims(), indexRemap.getNumSymbols(), resultExpr);
-        auto afOp = builder.create<AffineApplyOp>(opInst->getLoc(),
-                                                  singleResMap, remapOperands);
-        remapOutputs.push_back(afOp);
-        affineApplyOps.push_back(afOp);
-      }
-    } else {
-      // No remapping specified.
-      remapOutputs.append(remapOperands.begin(), remapOperands.end());
-    }
-
-    SmallVector<Value *, 4> newMapOperands;
-    newMapOperands.reserve(newMemRefRank);
-
-    // Prepend 'extraIndices' in 'newMapOperands'.
-    for (auto *extraIndex : extraIndices) {
-      assert(extraIndex->getDefiningOp()->getNumResults() == 1 &&
-             "single result op's expected to generate these indices");
-      assert((isValidDim(extraIndex) || isValidSymbol(extraIndex)) &&
-             "invalid memory op index");
-      newMapOperands.push_back(extraIndex);
-    }
-
-    // Append 'remapOutputs' to 'newMapOperands'.
-    newMapOperands.append(remapOutputs.begin(), remapOutputs.end());
-
-    // Create new fully composed AffineMap for new op to be created.
-    assert(newMapOperands.size() == newMemRefRank);
-    auto newMap = builder.getMultiDimIdentityMap(newMemRefRank);
-    // TODO(b/136262594) Avoid creating/deleting temporary AffineApplyOps here.
-    fullyComposeAffineMapAndOperands(&newMap, &newMapOperands);
-    newMap = simplifyAffineMap(newMap);
-    canonicalizeMapAndOperands(&newMap, &newMapOperands);
-    // Remove any affine.apply's that became dead as a result of composition.
-    for (auto *value : affineApplyOps)
-      if (value->use_empty())
-        value->getDefiningOp()->erase();
-
-    // Construct the new operation using this memref.
-    OperationState state(opInst->getLoc(), opInst->getName());
-    state.setOperandListToResizable(opInst->hasResizableOperandsList());
-    state.operands.reserve(opInst->getNumOperands() + extraIndices.size());
-    // Insert the non-memref operands.
-    state.operands.append(opInst->operand_begin(),
-                          opInst->operand_begin() + memRefOperandPos);
-    // Insert the new memref value.
-    state.operands.push_back(newMemRef);
-
-    // Insert the new memref map operands.
-    state.operands.append(newMapOperands.begin(), newMapOperands.end());
-
-    // Insert the remaining operands unmodified.
-    state.operands.append(opInst->operand_begin() + memRefOperandPos + 1 +
-                              oldMapNumInputs,
-                          opInst->operand_end());
-
-    // Result types don't change. Both memref's are of the same elemental type.
-    state.types.reserve(opInst->getNumResults());
-    for (auto *result : opInst->getResults())
-      state.types.push_back(result->getType());
-
-    // Add attribute for 'newMap', other Attributes do not change.
-    auto newMapAttr = builder.getAffineMapAttr(newMap);
-    for (auto namedAttr : opInst->getAttrs()) {
-      if (namedAttr.first == oldMapAttrPair.first) {
-        state.attributes.push_back({namedAttr.first, newMapAttr});
-      } else {
-        state.attributes.push_back(namedAttr);
-      }
-    }
-
-    // Create the new operation.
-    auto *repOp = builder.createOperation(state);
-    opInst->replaceAllUsesWith(repOp);
-
-    // Collect and erase at the end since one of these op's could be
-    // domInstFilter or postDomInstFilter as well!
-    opsToErase.push_back(opInst);
+    // We'll first collect and then replace --- since replacement erases the op
+    // that has the use, and that op could be postDomFilter or domFilter itself!
+    opsToReplace.insert(op);
   }
 
-  for (auto *opInst : opsToErase)
-    opInst->erase();
+  for (auto *op : opsToReplace) {
+    if (failed(replaceAllMemRefUsesWith(oldMemRef, newMemRef, op, extraIndices,
+                                        indexRemap, extraOperands)))
+      assert(false && "memref replacement guaranteed to succeed here");
+  }
 
-  return true;
+  return success();
 }
 
 /// Given an operation, inserts one or more single result affine
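The collect-then-replace split above is the standard defense against mutating
a use list while iterating it; a generic illustration of the hazard
(`rewriteUser` is a hypothetical mutation that may erase its argument, not
part of this commit):

    // Wrong: rewriting erases `user`, invalidating the in-flight use range.
    //   for (auto *user : oldMemRef->getUsers()) rewriteUser(user);
    // Right: snapshot the users, then mutate.
    DenseSet<Operation *> worklist;
    for (auto *user : oldMemRef->getUsers())
      worklist.insert(user);
    for (auto *user : worklist)
      rewriteUser(user);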