Refactor / improve replaceAllMemRefUsesWith

Refactor replaceAllMemRefUsesWith by splitting it into two methods: the
new method performs the replacement on a single op, and the existing one
now uses it.
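
After the split, the two entry points are (copied from the header change
below):

  LogicalResult replaceAllMemRefUsesWith(Value *oldMemRef, Value *newMemRef,
                                         ArrayRef<Value *> extraIndices = {},
                                         AffineMap indexRemap = AffineMap(),
                                         ArrayRef<Value *> extraOperands = {},
                                         Operation *domInstFilter = nullptr,
                                         Operation *postDomInstFilter = nullptr);

  /// Performs the same replacement but only for the dereferencing uses of
  /// `oldMemRef` in `op`.
  LogicalResult replaceAllMemRefUsesWith(Value *oldMemRef, Value *newMemRef,
                                         Operation *op,
                                         ArrayRef<Value *> extraIndices = {},
                                         AffineMap indexRemap = AffineMap(),
                                         ArrayRef<Value *> extraOperands = {});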

- Make the methods return LogicalResult instead of bool.
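
  For callers this is the usual LogicalResult convention; a minimal sketch
  (operand names illustrative, relying on the default arguments above):

    if (failed(replaceAllMemRefUsesWith(oldMemRef, newMemRef)))
      return failure(); // Nothing was rewritten; the IR is still consistent.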

- Earlier, when replacement failed (due to non-dereferencing uses of the
  memref), the ops that had already been processed were left replaced,
  leaving the IR in an inconsistent state. Now, a pass is first made over
  all ops to check for non-dereferencing uses, and only then is the
  replacement performed. No test cases were affected because all existing
  clients of this method were already checking for non-dereferencing uses
  before calling it (for other reasons). That isn't true for a use case in
  an upcoming PR (scalar replacement); clients can now bail out with
  consistent IR when replaceAllMemRefUsesWith fails. Add a test case.
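
  In outline, the multi-op version now does the following (simplified
  sketch of the new Utils.cpp logic; dominance/post-dominance filtering
  elided):

    // First pass: fail early, before touching anything, on any
    // non-dereferencing use (dealloc's are skipped).
    DenseSet<Operation *> opsToReplace;
    for (auto *op : oldMemRef->getUsers()) {
      if (isa<DeallocOp>(op))
        continue;
      if (!isMemRefDereferencingOp(*op))
        return failure();
      opsToReplace.insert(op);
    }
    // Second pass: per-op replacement via the new single-op method.
    for (auto *op : opsToReplace)
      if (failed(replaceAllMemRefUsesWith(oldMemRef, newMemRef, op,
                                          extraIndices, indexRemap,
                                          extraOperands)))
        assert(false && "memref replacement guaranteed to succeed here");
    return success();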

- Multiple dereferencing uses of the same memref in a single op are
  possible (we have no such use cases/scenarios yet), and this has always
  remained unsupported. Add an assertion for it.
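
  The check in the single-op method (from the Utils.cpp change below):

    SmallVector<unsigned, 2> usePositions;
    for (const auto &opEntry : llvm::enumerate(op->getOperands()))
      if (opEntry.value() == oldMemRef)
        usePositions.push_back(opEntry.index());
    if (usePositions.size() > 1) {
      // TODO(mlir-team): extend it for this case when needed (rare).
      assert(false && "multiple dereferencing uses in a single op not supported");
      return failure();
    }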

- Minor fix to another pipeline-data-transfer test case.

Signed-off-by: Uday Bondhugula <uday@polymagelabs.com>

Closes #87

COPYBARA_INTEGRATE_REVIEW=https://github.com/tensorflow/mlir/pull/87 from bondhugula:memref 5153a6194d875eb0424ea8a4ebf1764da8035645
PiperOrigin-RevId: 265808183


@@ -37,26 +37,26 @@ class AffineForOp;
class Location;
class OpBuilder;
/// Replaces all "deferencing" uses of oldMemRef with newMemRef while optionally
/// remapping the old memref's indices using the supplied affine map,
/// 'indexRemap'. The new memref could be of a different shape or rank.
/// 'extraIndices' provides additional access indices to be added to the start.
/// Replaces all "dereferencing" uses of `oldMemRef` with `newMemRef` while
/// optionally remapping the old memref's indices using the supplied affine map,
/// `indexRemap`. The new memref could be of a different shape or rank.
/// `extraIndices` provides additional access indices to be added to the start.
///
/// 'indexRemap' remaps indices of the old memref access to a new set of indices
/// `indexRemap` remaps indices of the old memref access to a new set of indices
/// that are used to index the memref. Additional input operands to indexRemap
/// can be optionally provided, and they are added at the start of its input
/// list. 'indexRemap' is expected to have only dimensional inputs, and the
/// list. `indexRemap` is expected to have only dimensional inputs, and the
/// number of its inputs equal to extraOperands.size() plus rank of the memref.
/// 'extraOperands' is an optional argument that corresponds to additional
/// operands (inputs) for indexRemap at the beginning of its input list.
///
/// 'domInstFilter', if non-null, restricts the replacement to only those
/// `domInstFilter`, if non-null, restricts the replacement to only those
/// operations that are dominated by the former; similarly, `postDomInstFilter`
/// restricts replacement to only those operations that are postdominated by it.
///
/// Returns true on success and false if the replacement is not possible,
/// whenever a memref is used as an operand in a non-deferencing context, except
/// for dealloc's on the memref which are left untouched. See comments at
/// whenever a memref is used as an operand in a non-dereferencing context,
/// except for dealloc's on the memref which are left untouched. See comments at
/// function definition for an example.
//
// Ex: to replace load %A[%i, %j] with load %Abuf[%t mod 2, %ii - %i, %j]:
@@ -66,12 +66,20 @@ class OpBuilder;
// extra operands, note that 'indexRemap' would just be applied to existing
// indices (%i, %j).
// TODO(bondhugula): allow extraIndices to be added at any position.
bool replaceAllMemRefUsesWith(Value *oldMemRef, Value *newMemRef,
ArrayRef<Value *> extraIndices = {},
AffineMap indexRemap = AffineMap(),
ArrayRef<Value *> extraOperands = {},
Operation *domInstFilter = nullptr,
Operation *postDomInstFilter = nullptr);
LogicalResult replaceAllMemRefUsesWith(Value *oldMemRef, Value *newMemRef,
ArrayRef<Value *> extraIndices = {},
AffineMap indexRemap = AffineMap(),
ArrayRef<Value *> extraOperands = {},
Operation *domInstFilter = nullptr,
Operation *postDomInstFilter = nullptr);
/// Performs the same replacement as the other version above but only for the
/// dereferencing uses of `oldMemRef` in `op`.
LogicalResult replaceAllMemRefUsesWith(Value *oldMemRef, Value *newMemRef,
Operation *op,
ArrayRef<Value *> extraIndices = {},
AffineMap indexRemap = AffineMap(),
ArrayRef<Value *> extraOperands = {});
/// Creates and inserts into 'builder' a new AffineApplyOp, with the number of
/// its results equal to the number of operands, as a composition


@@ -952,12 +952,13 @@ static Value *createPrivateMemRef(AffineForOp forOp, Operation *srcStoreOpInst,
? AffineMap()
: b.getAffineMap(outerIVs.size() + rank, 0, remapExprs);
// Replace all users of 'oldMemRef' with 'newMemRef'.
bool ret =
LogicalResult res =
replaceAllMemRefUsesWith(oldMemRef, newMemRef, {}, indexRemap,
/*extraOperands=*/outerIVs,
/*domInstFilter=*/&*forOp.getBody()->begin());
assert(ret && "replaceAllMemrefUsesWith should always succeed here");
(void)ret;
assert(succeeded(res) &&
"replaceAllMemrefUsesWith should always succeed here");
(void)res;
return newMemRef;
}


@@ -115,13 +115,14 @@ static bool doubleBuffer(Value *oldMemRef, AffineForOp forOp) {
auto ivModTwoOp = bInner.create<AffineApplyOp>(forOp.getLoc(), modTwoMap,
forOp.getInductionVar());
// replaceAllMemRefUsesWith will always succeed unless the forOp body has
// non-deferencing uses of the memref (dealloc's are fine though).
if (!replaceAllMemRefUsesWith(oldMemRef, newMemRef,
/*extraIndices=*/{ivModTwoOp},
/*indexRemap=*/AffineMap(),
/*extraOperands=*/{},
/*domInstFilter=*/&*forOp.getBody()->begin())) {
// replaceAllMemRefUsesWith will succeed unless the forOp body has
// non-dereferencing uses of the memref (dealloc's are fine though).
if (failed(replaceAllMemRefUsesWith(
oldMemRef, newMemRef,
/*extraIndices=*/{ivModTwoOp},
/*indexRemap=*/AffineMap(),
/*extraOperands=*/{},
/*domInstFilter=*/&*forOp.getBody()->begin()))) {
LLVM_DEBUG(
forOp.emitError("memref replacement for double buffering failed"));
ivModTwoOp.erase();
@@ -276,9 +277,9 @@ void PipelineDataTransfer::runOnAffineForOp(AffineForOp forOp) {
if (!doubleBuffer(oldMemRef, forOp)) {
// Normally, double buffering should not fail because we already checked
// that there are no uses outside.
LLVM_DEBUG(llvm::dbgs() << "double buffering failed for: \n";);
LLVM_DEBUG(dmaStartInst->dump());
// IR still in a valid state.
LLVM_DEBUG(llvm::dbgs()
<< "double buffering failed for" << dmaStartInst << "\n";);
// IR still valid and semantically correct.
return;
}
// If the old memref has no more uses, remove its 'dead' alloc if it was


@@ -57,16 +57,181 @@ static NamedAttribute getAffineMapAttrForMemRef(Operation *op, Value *memref) {
return cast<AffineDmaWaitOp>(op).getAffineMapAttrForMemRef(memref);
}
bool mlir::replaceAllMemRefUsesWith(Value *oldMemRef, Value *newMemRef,
ArrayRef<Value *> extraIndices,
AffineMap indexRemap,
ArrayRef<Value *> extraOperands,
Operation *domInstFilter,
Operation *postDomInstFilter) {
// Perform the replacement in `op`.
LogicalResult mlir::replaceAllMemRefUsesWith(Value *oldMemRef, Value *newMemRef,
Operation *op,
ArrayRef<Value *> extraIndices,
AffineMap indexRemap,
ArrayRef<Value *> extraOperands) {
unsigned newMemRefRank = newMemRef->getType().cast<MemRefType>().getRank();
(void)newMemRefRank; // unused in opt mode
unsigned oldMemRefRank = oldMemRef->getType().cast<MemRefType>().getRank();
(void)newMemRefRank;
(void)oldMemRefRank;
if (indexRemap) {
assert(indexRemap.getNumSymbols() == 0 && "pure dimensional map expected");
assert(indexRemap.getNumInputs() == extraOperands.size() + oldMemRefRank);
assert(indexRemap.getNumResults() + extraIndices.size() == newMemRefRank);
} else {
assert(oldMemRefRank + extraIndices.size() == newMemRefRank);
}
// Assert same elemental type.
assert(oldMemRef->getType().cast<MemRefType>().getElementType() ==
newMemRef->getType().cast<MemRefType>().getElementType());
if (!isMemRefDereferencingOp(*op))
// Failure: memref used in a non-dereferencing context (potentially
// escapes); no replacement in these cases.
return failure();
SmallVector<unsigned, 2> usePositions;
for (const auto &opEntry : llvm::enumerate(op->getOperands())) {
if (opEntry.value() == oldMemRef)
usePositions.push_back(opEntry.index());
}
// If memref doesn't appear, nothing to do.
if (usePositions.empty())
return success();
if (usePositions.size() > 1) {
// TODO(mlir-team): extend it for this case when needed (rare).
assert(false && "multiple dereferencing uses in a single op not supported");
return failure();
}
unsigned memRefOperandPos = usePositions.front();
OpBuilder builder(op);
NamedAttribute oldMapAttrPair = getAffineMapAttrForMemRef(op, oldMemRef);
AffineMap oldMap = oldMapAttrPair.second.cast<AffineMapAttr>().getValue();
unsigned oldMapNumInputs = oldMap.getNumInputs();
SmallVector<Value *, 4> oldMapOperands(
op->operand_begin() + memRefOperandPos + 1,
op->operand_begin() + memRefOperandPos + 1 + oldMapNumInputs);
// Apply 'oldMemRefOperands = oldMap(oldMapOperands)'.
SmallVector<Value *, 4> oldMemRefOperands;
SmallVector<Value *, 4> affineApplyOps;
oldMemRefOperands.reserve(oldMemRefRank);
if (oldMap != builder.getMultiDimIdentityMap(oldMap.getNumDims())) {
for (auto resultExpr : oldMap.getResults()) {
auto singleResMap = builder.getAffineMap(
oldMap.getNumDims(), oldMap.getNumSymbols(), resultExpr);
auto afOp = builder.create<AffineApplyOp>(op->getLoc(), singleResMap,
oldMapOperands);
oldMemRefOperands.push_back(afOp);
affineApplyOps.push_back(afOp);
}
} else {
oldMemRefOperands.append(oldMapOperands.begin(), oldMapOperands.end());
}
// Construct new indices as a remap of the old ones if a remapping has been
// provided. The indices of a memref come right after it, i.e.,
// at position memRefOperandPos + 1.
SmallVector<Value *, 4> remapOperands;
remapOperands.reserve(extraOperands.size() + oldMemRefRank);
remapOperands.append(extraOperands.begin(), extraOperands.end());
remapOperands.append(oldMemRefOperands.begin(), oldMemRefOperands.end());
SmallVector<Value *, 4> remapOutputs;
remapOutputs.reserve(oldMemRefRank);
if (indexRemap &&
indexRemap != builder.getMultiDimIdentityMap(indexRemap.getNumDims())) {
// Remapped indices.
for (auto resultExpr : indexRemap.getResults()) {
auto singleResMap = builder.getAffineMap(
indexRemap.getNumDims(), indexRemap.getNumSymbols(), resultExpr);
auto afOp = builder.create<AffineApplyOp>(op->getLoc(), singleResMap,
remapOperands);
remapOutputs.push_back(afOp);
affineApplyOps.push_back(afOp);
}
} else {
// No remapping specified.
remapOutputs.append(remapOperands.begin(), remapOperands.end());
}
SmallVector<Value *, 4> newMapOperands;
newMapOperands.reserve(newMemRefRank);
// Prepend 'extraIndices' in 'newMapOperands'.
for (auto *extraIndex : extraIndices) {
assert(extraIndex->getDefiningOp()->getNumResults() == 1 &&
"single result op's expected to generate these indices");
assert((isValidDim(extraIndex) || isValidSymbol(extraIndex)) &&
"invalid memory op index");
newMapOperands.push_back(extraIndex);
}
// Append 'remapOutputs' to 'newMapOperands'.
newMapOperands.append(remapOutputs.begin(), remapOutputs.end());
// Create new fully composed AffineMap for new op to be created.
assert(newMapOperands.size() == newMemRefRank);
auto newMap = builder.getMultiDimIdentityMap(newMemRefRank);
// TODO(b/136262594) Avoid creating/deleting temporary AffineApplyOps here.
fullyComposeAffineMapAndOperands(&newMap, &newMapOperands);
newMap = simplifyAffineMap(newMap);
canonicalizeMapAndOperands(&newMap, &newMapOperands);
// Remove any affine.apply's that became dead as a result of composition.
for (auto *value : affineApplyOps)
if (value->use_empty())
value->getDefiningOp()->erase();
// Construct the new operation using this memref.
OperationState state(op->getLoc(), op->getName());
state.setOperandListToResizable(op->hasResizableOperandsList());
state.operands.reserve(op->getNumOperands() + extraIndices.size());
// Insert the non-memref operands.
state.operands.append(op->operand_begin(),
op->operand_begin() + memRefOperandPos);
// Insert the new memref value.
state.operands.push_back(newMemRef);
// Insert the new memref map operands.
state.operands.append(newMapOperands.begin(), newMapOperands.end());
// Insert the remaining operands unmodified.
state.operands.append(op->operand_begin() + memRefOperandPos + 1 +
oldMapNumInputs,
op->operand_end());
// Result types don't change. Both memref's are of the same elemental type.
state.types.reserve(op->getNumResults());
for (auto *result : op->getResults())
state.types.push_back(result->getType());
// Add attribute for 'newMap', other Attributes do not change.
auto newMapAttr = builder.getAffineMapAttr(newMap);
for (auto namedAttr : op->getAttrs()) {
if (namedAttr.first == oldMapAttrPair.first) {
state.attributes.push_back({namedAttr.first, newMapAttr});
} else {
state.attributes.push_back(namedAttr);
}
}
// Create the new operation.
auto *repOp = builder.createOperation(state);
op->replaceAllUsesWith(repOp);
op->erase();
return success();
}
LogicalResult mlir::replaceAllMemRefUsesWith(Value *oldMemRef, Value *newMemRef,
ArrayRef<Value *> extraIndices,
AffineMap indexRemap,
ArrayRef<Value *> extraOperands,
Operation *domInstFilter,
Operation *postDomInstFilter) {
unsigned newMemRefRank = newMemRef->getType().cast<MemRefType>().getRank();
(void)newMemRefRank; // unused in opt mode
unsigned oldMemRefRank = oldMemRef->getType().cast<MemRefType>().getRank();
(void)oldMemRefRank;
if (indexRemap) {
assert(indexRemap.getNumSymbols() == 0 && "pure dimensional map expected");
assert(indexRemap.getNumInputs() == extraOperands.size() + oldMemRefRank);
@@ -89,170 +254,44 @@ bool mlir::replaceAllMemRefUsesWith(Value *oldMemRef, Value *newMemRef,
postDomInfo = std::make_unique<PostDominanceInfo>(
postDomInstFilter->getParentOfType<FuncOp>());
// The ops where memref replacement succeeds are replaced with new ones.
SmallVector<Operation *, 8> opsToErase;
// Walk all uses of old memref. Operation using the memref gets replaced.
for (auto *opInst : llvm::make_early_inc_range(oldMemRef->getUsers())) {
// Walk all uses of old memref; collect ops to perform replacement. We use a
// DenseSet since an operation could potentially have multiple uses of a
// memref (although rare), and the replacement later is going to erase ops.
DenseSet<Operation *> opsToReplace;
for (auto *op : oldMemRef->getUsers()) {
// Skip this use if it's not dominated by domInstFilter.
if (domInstFilter && !domInfo->dominates(domInstFilter, opInst))
if (domInstFilter && !domInfo->dominates(domInstFilter, op))
continue;
// Skip this use if it's not post-dominated by postDomInstFilter.
if (postDomInstFilter &&
!postDomInfo->postDominates(postDomInstFilter, opInst))
if (postDomInstFilter && !postDomInfo->postDominates(postDomInstFilter, op))
continue;
// Skip dealloc's - no replacement is necessary, and a replacement doesn't
// hurt dealloc's.
if (isa<DeallocOp>(opInst))
// Skip dealloc's - no replacement is necessary, and a memref replacement
// at other uses doesn't hurt these dealloc's.
if (isa<DeallocOp>(op))
continue;
// Check if the memref was used in a non-deferencing context. It is fine for
// the memref to be used in a non-deferencing way outside of the region
// where this replacement is happening.
if (!isMemRefDereferencingOp(*opInst))
// Failure: memref used in a non-deferencing op (potentially escapes); no
// replacement in these cases.
return false;
// Check if the memref was used in a non-dereferencing context. It is fine
// for the memref to be used in a non-dereferencing way outside of the
// region where this replacement is happening.
if (!isMemRefDereferencingOp(*op))
// Failure: memref used in a non-dereferencing op (potentially escapes);
// no replacement in these cases.
return failure();
auto getMemRefOperandPos = [&]() -> unsigned {
unsigned i, e;
for (i = 0, e = opInst->getNumOperands(); i < e; i++) {
if (opInst->getOperand(i) == oldMemRef)
break;
}
assert(i < opInst->getNumOperands() && "operand guaranteed to be found");
return i;
};
OpBuilder builder(opInst);
unsigned memRefOperandPos = getMemRefOperandPos();
NamedAttribute oldMapAttrPair =
getAffineMapAttrForMemRef(opInst, oldMemRef);
AffineMap oldMap = oldMapAttrPair.second.cast<AffineMapAttr>().getValue();
unsigned oldMapNumInputs = oldMap.getNumInputs();
SmallVector<Value *, 4> oldMapOperands(
opInst->operand_begin() + memRefOperandPos + 1,
opInst->operand_begin() + memRefOperandPos + 1 + oldMapNumInputs);
SmallVector<Value *, 4> affineApplyOps;
// Apply 'oldMemRefOperands = oldMap(oldMapOperands)'.
SmallVector<Value *, 4> oldMemRefOperands;
oldMemRefOperands.reserve(oldMemRefRank);
if (oldMap != builder.getMultiDimIdentityMap(oldMap.getNumDims())) {
for (auto resultExpr : oldMap.getResults()) {
auto singleResMap = builder.getAffineMap(
oldMap.getNumDims(), oldMap.getNumSymbols(), resultExpr);
auto afOp = builder.create<AffineApplyOp>(opInst->getLoc(),
singleResMap, oldMapOperands);
oldMemRefOperands.push_back(afOp);
affineApplyOps.push_back(afOp);
}
} else {
oldMemRefOperands.append(oldMapOperands.begin(), oldMapOperands.end());
}
// Construct new indices as a remap of the old ones if a remapping has been
// provided. The indices of a memref come right after it, i.e.,
// at position memRefOperandPos + 1.
SmallVector<Value *, 4> remapOperands;
remapOperands.reserve(extraOperands.size() + oldMemRefRank);
remapOperands.append(extraOperands.begin(), extraOperands.end());
remapOperands.append(oldMemRefOperands.begin(), oldMemRefOperands.end());
SmallVector<Value *, 4> remapOutputs;
remapOutputs.reserve(oldMemRefRank);
if (indexRemap &&
indexRemap != builder.getMultiDimIdentityMap(indexRemap.getNumDims())) {
// Remapped indices.
for (auto resultExpr : indexRemap.getResults()) {
auto singleResMap = builder.getAffineMap(
indexRemap.getNumDims(), indexRemap.getNumSymbols(), resultExpr);
auto afOp = builder.create<AffineApplyOp>(opInst->getLoc(),
singleResMap, remapOperands);
remapOutputs.push_back(afOp);
affineApplyOps.push_back(afOp);
}
} else {
// No remapping specified.
remapOutputs.append(remapOperands.begin(), remapOperands.end());
}
SmallVector<Value *, 4> newMapOperands;
newMapOperands.reserve(newMemRefRank);
// Prepend 'extraIndices' in 'newMapOperands'.
for (auto *extraIndex : extraIndices) {
assert(extraIndex->getDefiningOp()->getNumResults() == 1 &&
"single result op's expected to generate these indices");
assert((isValidDim(extraIndex) || isValidSymbol(extraIndex)) &&
"invalid memory op index");
newMapOperands.push_back(extraIndex);
}
// Append 'remapOutputs' to 'newMapOperands'.
newMapOperands.append(remapOutputs.begin(), remapOutputs.end());
// Create new fully composed AffineMap for new op to be created.
assert(newMapOperands.size() == newMemRefRank);
auto newMap = builder.getMultiDimIdentityMap(newMemRefRank);
// TODO(b/136262594) Avoid creating/deleting temporary AffineApplyOps here.
fullyComposeAffineMapAndOperands(&newMap, &newMapOperands);
newMap = simplifyAffineMap(newMap);
canonicalizeMapAndOperands(&newMap, &newMapOperands);
// Remove any affine.apply's that became dead as a result of composition.
for (auto *value : affineApplyOps)
if (value->use_empty())
value->getDefiningOp()->erase();
// Construct the new operation using this memref.
OperationState state(opInst->getLoc(), opInst->getName());
state.setOperandListToResizable(opInst->hasResizableOperandsList());
state.operands.reserve(opInst->getNumOperands() + extraIndices.size());
// Insert the non-memref operands.
state.operands.append(opInst->operand_begin(),
opInst->operand_begin() + memRefOperandPos);
// Insert the new memref value.
state.operands.push_back(newMemRef);
// Insert the new memref map operands.
state.operands.append(newMapOperands.begin(), newMapOperands.end());
// Insert the remaining operands unmodified.
state.operands.append(opInst->operand_begin() + memRefOperandPos + 1 +
oldMapNumInputs,
opInst->operand_end());
// Result types don't change. Both memref's are of the same elemental type.
state.types.reserve(opInst->getNumResults());
for (auto *result : opInst->getResults())
state.types.push_back(result->getType());
// Add attribute for 'newMap', other Attributes do not change.
auto newMapAttr = builder.getAffineMapAttr(newMap);
for (auto namedAttr : opInst->getAttrs()) {
if (namedAttr.first == oldMapAttrPair.first) {
state.attributes.push_back({namedAttr.first, newMapAttr});
} else {
state.attributes.push_back(namedAttr);
}
}
// Create the new operation.
auto *repOp = builder.createOperation(state);
opInst->replaceAllUsesWith(repOp);
// Collect and erase at the end since one of these op's could be
// domInstFilter or postDomInstFilter as well!
opsToErase.push_back(opInst);
// We'll first collect and then replace --- since replacement erases the op
// that has the use, and that op could be postDomFilter or domFilter itself!
opsToReplace.insert(op);
}
for (auto *opInst : opsToErase)
opInst->erase();
for (auto *op : opsToReplace) {
if (failed(replaceAllMemRefUsesWith(oldMemRef, newMemRef, op, extraIndices,
indexRemap, extraOperands)))
assert(false && "memref replacement guaranteed to succeed here");
}
return true;
return success();
}
/// Given an operation, inserts one or more single result affine