Integer set + operands / affine if op canonicalization

- turn canonicalizeMapAndOperands into a template that works on both
  sets and maps, and use it to introduce a utility to canonicalize an
  affine integer set and its operands
- add pattern to canonicalize affine if op's.
- rename IntegerSet::getNumOperands -> IntegerSet::getNumInputs to be
  consistent with AffineMap
- add missing accessors for IntegerSet

Doesn't need extensive testing since canonicalizeSetAndOperands just
reuses canonicalizeMapAndOperands' logic, and the latter is tested on
affine.apply map + operands; the new method works the same way on an
integer set + operands of an affine if op for example.

Signed-off-by: Uday Bondhugula <uday@polymagelabs.com>

Closes 

COPYBARA_INTEGRATE_REVIEW=https://github.com/tensorflow/mlir/pull/112 from bondhugula:set-canonicalize eff72f23250b96fa7d9f5caff3877440f5de2cec
PiperOrigin-RevId: 267532876
This commit is contained in:
Uday Bondhugula 2019-09-05 23:12:01 -07:00 committed by TensorFlower Gardener
parent cf27972e28
commit 2ad042991c
6 changed files with 139 additions and 38 deletions
third_party/mlir
include/mlir
lib

View File

@ -522,6 +522,10 @@ bool isValidSymbol(Value *value);
/// 2. drop unused dims and symbols from map
void canonicalizeMapAndOperands(AffineMap *map,
llvm::SmallVectorImpl<Value *> *operands);
/// Canonicalizes an integer set the same way canonicalizeMapAndOperands does
/// for affine maps.
void canonicalizeSetAndOperands(IntegerSet *set,
llvm::SmallVectorImpl<Value *> *operands);
/// Returns a composed AffineApplyOp by composing `map` and `operands` with
/// other AffineApplyOps supplying those operands. The operands of the resulting

View File

@ -223,6 +223,11 @@ def AffineIfOp : Affine_Op<"if", [ImplicitAffineTerminator]> {
IntegerSet getIntegerSet();
void setIntegerSet(IntegerSet newSet);
/// Sets the integer set with its operands. The size of 'operands' must not
/// exceed the current number of operands for this instance, as the operands
/// list of AffineIf is not resizable.
void setConditional(IntegerSet set, ArrayRef<Value *> operands);
OpBuilder getThenBodyBuilder() {
assert(!thenRegion().empty() && "Unexpected empty 'then' region.");
Block &body = thenRegion().front();
@ -234,6 +239,8 @@ def AffineIfOp : Affine_Op<"if", [ImplicitAffineTerminator]> {
return OpBuilder(&body, std::prev(body.end()));
}
}];
let hasCanonicalizer = 1;
}
def AffineTerminatorOp :

View File

@ -72,12 +72,22 @@ public:
/// Returns true if this is the canonical integer set.
bool isEmptyIntegerSet() const;
/// This method substitutes any uses of dimensions and symbols (e.g.
/// dim#0 with dimReplacements[0]) in subexpressions and returns the modified
/// integer set. Because this can be used to eliminate dims and
/// symbols, the client needs to specify the number of dims and symbols in
/// the result. The returned map always has the same number of results.
IntegerSet replaceDimsAndSymbols(ArrayRef<AffineExpr> dimReplacements,
ArrayRef<AffineExpr> symReplacements,
unsigned numResultDims,
unsigned numResultSyms);
explicit operator bool() { return set; }
bool operator==(IntegerSet other) const { return set == other.set; }
unsigned getNumDims() const;
unsigned getNumSymbols() const;
unsigned getNumOperands() const;
unsigned getNumInputs() const;
unsigned getNumConstraints() const;
unsigned getNumEqualities() const;
unsigned getNumInequalities() const;
@ -96,6 +106,10 @@ public:
MLIRContext *getContext() const;
/// Walk all of the AffineExpr's in this set's constraints. Each node in an
/// expression tree is visited in postorder.
void walkExprs(llvm::function_ref<void(AffineExpr)> callback) const;
void print(raw_ostream &os) const;
void dump() const;

View File

@ -308,7 +308,7 @@ std::unique_ptr<FlatAffineConstraints> FlatAffineConstraints::clone() const {
// Construct from an IntegerSet.
FlatAffineConstraints::FlatAffineConstraints(IntegerSet set)
: numReservedCols(set.getNumOperands() + 1),
: numReservedCols(set.getNumInputs() + 1),
numIds(set.getNumDims() + set.getNumSymbols()), numDims(set.getNumDims()),
numSymbols(set.getNumSymbols()) {
equalities.reserve(set.getNumEqualities() * numReservedCols);

View File

@ -620,26 +620,27 @@ AffineApplyOp mlir::makeComposedAffineApply(OpBuilder &b, Location loc,
// A symbol may appear as a dim in affine.apply operations. This function
// canonicalizes dims that are valid symbols into actual symbols.
template <class MapOrSet>
static void
canonicalizePromotedSymbols(AffineMap *map,
canonicalizePromotedSymbols(MapOrSet *mapOrSet,
llvm::SmallVectorImpl<Value *> *operands) {
if (!map || operands->empty())
if (!mapOrSet || operands->empty())
return;
assert(map->getNumInputs() == operands->size() &&
"map inputs must match number of operands");
assert(mapOrSet->getNumInputs() == operands->size() &&
"map/set inputs must match number of operands");
auto *context = map->getContext();
auto *context = mapOrSet->getContext();
SmallVector<Value *, 8> resultOperands;
resultOperands.reserve(operands->size());
SmallVector<Value *, 8> remappedSymbols;
remappedSymbols.reserve(operands->size());
unsigned nextDim = 0;
unsigned nextSym = 0;
unsigned oldNumSyms = map->getNumSymbols();
SmallVector<AffineExpr, 8> dimRemapping(map->getNumDims());
for (unsigned i = 0, e = map->getNumInputs(); i != e; ++i) {
if (i < map->getNumDims()) {
unsigned oldNumSyms = mapOrSet->getNumSymbols();
SmallVector<AffineExpr, 8> dimRemapping(mapOrSet->getNumDims());
for (unsigned i = 0, e = mapOrSet->getNumInputs(); i != e; ++i) {
if (i < mapOrSet->getNumDims()) {
if (isValidSymbol((*operands)[i])) {
// This is a valid symbol that appears as a dim, canonicalize it.
dimRemapping[i] = getAffineSymbolExpr(oldNumSyms + nextSym++, context);
@ -655,42 +656,49 @@ canonicalizePromotedSymbols(AffineMap *map,
resultOperands.append(remappedSymbols.begin(), remappedSymbols.end());
*operands = resultOperands;
*map = map->replaceDimsAndSymbols(dimRemapping, {}, nextDim,
oldNumSyms + nextSym);
*mapOrSet = mapOrSet->replaceDimsAndSymbols(dimRemapping, {}, nextDim,
oldNumSyms + nextSym);
assert(map->getNumInputs() == operands->size() &&
"map inputs must match number of operands");
assert(mapOrSet->getNumInputs() == operands->size() &&
"map/set inputs must match number of operands");
}
void mlir::canonicalizeMapAndOperands(
AffineMap *map, llvm::SmallVectorImpl<Value *> *operands) {
if (!map || operands->empty())
// Works for either an affine map or an integer set.
template <class MapOrSet>
static void
canonicalizeMapOrSetAndOperands(MapOrSet *mapOrSet,
llvm::SmallVectorImpl<Value *> *operands) {
static_assert(std::is_same<MapOrSet, AffineMap>::value ||
std::is_same<MapOrSet, IntegerSet>::value,
"Argument must be either of AffineMap or IntegerSet type");
if (!mapOrSet || operands->empty())
return;
assert(map->getNumInputs() == operands->size() &&
"map inputs must match number of operands");
assert(mapOrSet->getNumInputs() == operands->size() &&
"map/set inputs must match number of operands");
canonicalizePromotedSymbols(map, operands);
canonicalizePromotedSymbols<MapOrSet>(mapOrSet, operands);
// Check to see what dims are used.
llvm::SmallBitVector usedDims(map->getNumDims());
llvm::SmallBitVector usedSyms(map->getNumSymbols());
map->walkExprs([&](AffineExpr expr) {
llvm::SmallBitVector usedDims(mapOrSet->getNumDims());
llvm::SmallBitVector usedSyms(mapOrSet->getNumSymbols());
mapOrSet->walkExprs([&](AffineExpr expr) {
if (auto dimExpr = expr.dyn_cast<AffineDimExpr>())
usedDims[dimExpr.getPosition()] = true;
else if (auto symExpr = expr.dyn_cast<AffineSymbolExpr>())
usedSyms[symExpr.getPosition()] = true;
});
auto *context = map->getContext();
auto *context = mapOrSet->getContext();
SmallVector<Value *, 8> resultOperands;
resultOperands.reserve(operands->size());
llvm::SmallDenseMap<Value *, AffineExpr, 8> seenDims;
SmallVector<AffineExpr, 8> dimRemapping(map->getNumDims());
SmallVector<AffineExpr, 8> dimRemapping(mapOrSet->getNumDims());
unsigned nextDim = 0;
for (unsigned i = 0, e = map->getNumDims(); i != e; ++i) {
for (unsigned i = 0, e = mapOrSet->getNumDims(); i != e; ++i) {
if (usedDims[i]) {
// Remap dim positions for duplicate operands.
auto it = seenDims.find((*operands)[i]);
@ -704,37 +712,47 @@ void mlir::canonicalizeMapAndOperands(
}
}
llvm::SmallDenseMap<Value *, AffineExpr, 8> seenSymbols;
SmallVector<AffineExpr, 8> symRemapping(map->getNumSymbols());
SmallVector<AffineExpr, 8> symRemapping(mapOrSet->getNumSymbols());
unsigned nextSym = 0;
for (unsigned i = 0, e = map->getNumSymbols(); i != e; ++i) {
for (unsigned i = 0, e = mapOrSet->getNumSymbols(); i != e; ++i) {
if (!usedSyms[i])
continue;
// Handle constant operands (only needed for symbolic operands since
// constant operands in dimensional positions would have already been
// promoted to symbolic positions above).
IntegerAttr operandCst;
if (matchPattern((*operands)[i + map->getNumDims()],
if (matchPattern((*operands)[i + mapOrSet->getNumDims()],
m_Constant(&operandCst))) {
symRemapping[i] =
getAffineConstantExpr(operandCst.getValue().getSExtValue(), context);
continue;
}
// Remap symbol positions for duplicate operands.
auto it = seenSymbols.find((*operands)[i + map->getNumDims()]);
auto it = seenSymbols.find((*operands)[i + mapOrSet->getNumDims()]);
if (it == seenSymbols.end()) {
symRemapping[i] = getAffineSymbolExpr(nextSym++, context);
resultOperands.push_back((*operands)[i + map->getNumDims()]);
seenSymbols.insert(
std::make_pair((*operands)[i + map->getNumDims()], symRemapping[i]));
resultOperands.push_back((*operands)[i + mapOrSet->getNumDims()]);
seenSymbols.insert(std::make_pair((*operands)[i + mapOrSet->getNumDims()],
symRemapping[i]));
} else {
symRemapping[i] = it->second;
}
}
*map =
map->replaceDimsAndSymbols(dimRemapping, symRemapping, nextDim, nextSym);
*mapOrSet = mapOrSet->replaceDimsAndSymbols(dimRemapping, symRemapping,
nextDim, nextSym);
*operands = resultOperands;
}
void mlir::canonicalizeMapAndOperands(
AffineMap *map, llvm::SmallVectorImpl<Value *> *operands) {
canonicalizeMapOrSetAndOperands<AffineMap>(map, operands);
}
void mlir::canonicalizeSetAndOperands(
IntegerSet *set, llvm::SmallVectorImpl<Value *> *operands) {
canonicalizeMapOrSetAndOperands<IntegerSet>(set, operands);
}
namespace {
/// Simplify AffineApply operations.
///
@ -1540,7 +1558,7 @@ static LogicalResult verify(AffineIfOp op) {
// Verify that there are enough operands for the condition.
IntegerSet condition = conditionAttr.getValue();
if (op.getNumOperands() != condition.getNumOperands())
if (op.getNumOperands() != condition.getNumInputs())
return op.emitOpError(
"operand count and condition integer set dimension and "
"symbol count must match");
@ -1639,6 +1657,44 @@ void AffineIfOp::setIntegerSet(IntegerSet newSet) {
setAttr(getConditionAttrName(), IntegerSetAttr::get(newSet));
}
void AffineIfOp::setConditional(IntegerSet set, ArrayRef<Value *> operands) {
setIntegerSet(set);
getOperation()->setOperands(operands);
}
namespace {
// This is a pattern to canonicalize an affine if op's conditional (integer
// set + operands).
struct AffineIfOpCanonicalizer : public OpRewritePattern<AffineIfOp> {
using OpRewritePattern<AffineIfOp>::OpRewritePattern;
PatternMatchResult matchAndRewrite(AffineIfOp ifOp,
PatternRewriter &rewriter) const override {
auto set = ifOp.getIntegerSet();
SmallVector<Value *, 4> operands(ifOp.getOperands());
canonicalizeSetAndOperands(&set, &operands);
// Any canonicalization change always leads to either a reduction in the
// number of operands or a change in the number of symbolic operands
// (promotion of dims to symbols).
if (operands.size() < ifOp.getIntegerSet().getNumInputs() ||
set.getNumSymbols() > ifOp.getIntegerSet().getNumSymbols()) {
ifOp.setConditional(set, operands);
rewriter.updatedRootInPlace(ifOp);
return matchSuccess();
}
return matchFailure();
}
};
} // end anonymous namespace
void AffineIfOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) {
results.insert<AffineIfOpCanonicalizer>(context);
}
//===----------------------------------------------------------------------===//
// AffineLoadOp
//===----------------------------------------------------------------------===//

View File

@ -24,7 +24,7 @@ using namespace mlir::detail;
unsigned IntegerSet::getNumDims() const { return set->dimCount; }
unsigned IntegerSet::getNumSymbols() const { return set->symbolCount; }
unsigned IntegerSet::getNumOperands() const {
unsigned IntegerSet::getNumInputs() const {
return set->dimCount + set->symbolCount;
}
@ -70,3 +70,23 @@ bool IntegerSet::isEq(unsigned idx) const { return getEqFlags()[idx]; }
MLIRContext *IntegerSet::getContext() const {
return getConstraint(0).getContext();
}
/// Walk all of the AffineExpr's in this set. Each node in an expression
/// tree is visited in postorder.
void IntegerSet::walkExprs(
llvm::function_ref<void(AffineExpr)> callback) const {
for (auto expr : getConstraints())
expr.walk(callback);
}
IntegerSet IntegerSet::replaceDimsAndSymbols(
ArrayRef<AffineExpr> dimReplacements, ArrayRef<AffineExpr> symReplacements,
unsigned numResultDims, unsigned numResultSyms) {
SmallVector<AffineExpr, 8> constraints;
constraints.reserve(getNumConstraints());
for (auto cst : getConstraints())
constraints.push_back(
cst.replaceDimsAndSymbols(dimReplacements, symReplacements));
return get(numResultDims, numResultSyms, constraints, getEqFlags());
}