Insert signature-converted blocks into a region with a parent operation.

This keeps the IR valid and consistent as it is expected that each block should have a valid parent region/operation. Previously, converted blocks were kept floating without a valid parent region.

PiperOrigin-RevId: 285821687
Change-Id: I9650b562d25b3956becc1c8b46c7af5e83694eae
This commit is contained in:
River Riddle 2019-12-16 12:09:14 -08:00 committed by TensorFlower Gardener
parent e1666f32ba
commit 8f2154320a
3 changed files with 61 additions and 26 deletions

View File

@ -56,10 +56,14 @@ public:
/// Return if this block is the entry block in the parent region.
bool isEntryBlock();
/// Insert this block (which must not already be in a function) right before
/// Insert this block (which must not already be in a region) right before
/// the specified block.
void insertBefore(Block *block);
/// Unlink this block from its current region and insert it right before the
/// specific block.
void moveBefore(Block *block);
/// Unlink this Block from its parent region and delete it.
void erase();

View File

@ -57,7 +57,15 @@ bool Block::isEntryBlock() { return this == &getParent()->front(); }
void Block::insertBefore(Block *block) {
assert(!getParent() && "already inserted into a block!");
assert(block->getParent() && "cannot insert before a block without a parent");
block->getParent()->getBlocks().insert(Region::iterator(block), this);
block->getParent()->getBlocks().insert(block->getIterator(), this);
}
/// Unlink this block from its current region and insert it right before the
/// specific block.
void Block::moveBefore(Block *block) {
assert(block->getParent() && "cannot insert before a block without a parent");
block->getParent()->getBlocks().splice(
block->getIterator(), getParent()->getBlocks(), getIterator());
}
/// Unlink this Block from its parent Region and delete it.

View File

@ -164,9 +164,10 @@ struct ArgConverter {
// Rewrite Application
//===--------------------------------------------------------------------===//
/// Erase any rewrites registered for the current block that is about to be
/// removed. This merely drops the rewrites without undoing them.
void notifyBlockRemoved(Block *block);
/// Erase any rewrites registered for the blocks within the given operation
/// which is about to be removed. This merely drops the rewrites without
/// undoing them.
void notifyOpRemoved(Operation *op);
/// Cleanup and undo any generated conversions for the arguments of block.
/// This method replaces the new block with the original, reverting the IR to
@ -194,9 +195,16 @@ struct ArgConverter {
Block *block, TypeConverter::SignatureConversion &signatureConversion,
ConversionValueMapping &mapping);
/// Insert a new conversion into the cache.
void insertConversion(Block *newBlock, ConvertedBlockInfo &&info);
/// A collection of blocks that have had their arguments converted.
llvm::MapVector<Block *, ConvertedBlockInfo> conversionInfo;
/// A mapping from valid regions, to those containing the original blocks of a
/// conversion.
DenseMap<Region *, std::unique_ptr<Region>> regionMapping;
/// An instance of the unknown location that is used when materializing
/// conversions.
Location loc;
@ -212,18 +220,26 @@ struct ArgConverter {
//===----------------------------------------------------------------------===//
// Rewrite Application
void ArgConverter::notifyBlockRemoved(Block *block) {
auto it = conversionInfo.find(block);
if (it == conversionInfo.end())
return;
void ArgConverter::notifyOpRemoved(Operation *op) {
for (Region &region : op->getRegions()) {
for (Block &block : region) {
// Drop any rewrites from within.
for (Operation &nestedOp : block)
if (nestedOp.getNumRegions())
notifyOpRemoved(&nestedOp);
// Drop all uses of the original arguments and delete the original block.
Block *origBlock = it->second.origBlock;
for (BlockArgument *arg : origBlock->getArguments())
arg->dropAllUses();
delete origBlock;
// Check if this block was converted.
auto it = conversionInfo.find(&block);
if (it == conversionInfo.end())
return;
conversionInfo.erase(it);
// Drop all uses of the original arguments and delete the original block.
Block *origBlock = it->second.origBlock;
for (BlockArgument *arg : origBlock->getArguments())
arg->dropAllUses();
conversionInfo.erase(it);
}
}
}
void ArgConverter::discardRewrites(Block *block) {
@ -239,7 +255,7 @@ void ArgConverter::discardRewrites(Block *block) {
// Move the operations back the original block and the delete the new block.
origBlock->getOperations().splice(origBlock->end(), block->getOperations());
origBlock->insertBefore(block);
origBlock->moveBefore(block);
block->erase();
conversionInfo.erase(it);
@ -301,9 +317,6 @@ void ArgConverter::applyRewrites(ConversionValueMapping &mapping) {
if (castValue->use_empty())
castValue->getDefiningOp()->erase();
}
// Drop the original block now the rewrites were applied.
delete origBlock;
}
}
@ -377,11 +390,24 @@ Block *ArgConverter::applySignatureConversion(
}
// Remove the original block from the region and return the new one.
newBlock->getParent()->getBlocks().remove(block);
conversionInfo.insert({newBlock, std::move(info)});
insertConversion(newBlock, std::move(info));
return newBlock;
}
void ArgConverter::insertConversion(Block *newBlock,
ConvertedBlockInfo &&info) {
// Get a region to insert the old block.
Region *region = newBlock->getParent();
std::unique_ptr<Region> &mappedRegion = regionMapping[region];
if (!mappedRegion)
mappedRegion = std::make_unique<Region>(region->getParentOp());
// Move the original block to the mapped region and emplace the conversion.
mappedRegion->getBlocks().splice(mappedRegion->end(), region->getBlocks(),
info.origBlock->getIterator());
conversionInfo.insert({newBlock, std::move(info)});
}
//===----------------------------------------------------------------------===//
// ConversionPatternRewriterImpl
//===----------------------------------------------------------------------===//
@ -642,11 +668,8 @@ void ConversionPatternRewriterImpl::applyRewrites() {
// If this operation defines any regions, drop any pending argument
// rewrites.
if (argConverter.typeConverter && repl.op->getNumRegions()) {
for (auto &region : repl.op->getRegions())
for (auto &block : region)
argConverter.notifyBlockRemoved(&block);
}
if (argConverter.typeConverter && repl.op->getNumRegions())
argConverter.notifyOpRemoved(repl.op);
}
// In a second pass, erase all of the replaced operations in reverse. This