From 2ad042991c032bf38ea96b34814d44816fc7f78b Mon Sep 17 00:00:00 2001 From: Uday Bondhugula Date: Thu, 5 Sep 2019 23:12:01 -0700 Subject: [PATCH] Integer set + operands / affine if op canonicalization - turn canonicalizeMapAndOperands into a template that works on both sets and maps, and use it to introduce a utility to canonicalize an affine integer set and its operands - add pattern to canonicalize affine if op's. - rename IntegerSet::getNumOperands -> IntegerSet::getNumInputs to be consistent with AffineMap - add missing accessors for IntegerSet Doesn't need extensive testing since canonicalizeSetAndOperands just reuses canonicalizeMapAndOperands' logic, and the latter is tested on affine.apply map + operands; the new method works the same way on an integer set + operands of an affine if op for example. Signed-off-by: Uday Bondhugula Closes #112 COPYBARA_INTEGRATE_REVIEW=https://github.com/tensorflow/mlir/pull/112 from bondhugula:set-canonicalize eff72f23250b96fa7d9f5caff3877440f5de2cec PiperOrigin-RevId: 267532876 --- .../mlir/Dialect/AffineOps/AffineOps.h | 4 + .../mlir/Dialect/AffineOps/AffineOps.td | 7 + third_party/mlir/include/mlir/IR/IntegerSet.h | 16 ++- .../mlir/lib/Analysis/AffineStructures.cpp | 2 +- .../mlir/lib/Dialect/AffineOps/AffineOps.cpp | 126 +++++++++++++----- third_party/mlir/lib/IR/IntegerSet.cpp | 22 ++- 6 files changed, 139 insertions(+), 38 deletions(-) diff --git a/third_party/mlir/include/mlir/Dialect/AffineOps/AffineOps.h b/third_party/mlir/include/mlir/Dialect/AffineOps/AffineOps.h index a6af20eca0b..03b945c0a5f 100644 --- a/third_party/mlir/include/mlir/Dialect/AffineOps/AffineOps.h +++ b/third_party/mlir/include/mlir/Dialect/AffineOps/AffineOps.h @@ -522,6 +522,10 @@ bool isValidSymbol(Value *value); /// 2. drop unused dims and symbols from map void canonicalizeMapAndOperands(AffineMap *map, llvm::SmallVectorImpl *operands); +/// Canonicalizes an integer set the same way canonicalizeMapAndOperands does +/// for affine maps. +void canonicalizeSetAndOperands(IntegerSet *set, + llvm::SmallVectorImpl *operands); /// Returns a composed AffineApplyOp by composing `map` and `operands` with /// other AffineApplyOps supplying those operands. The operands of the resulting diff --git a/third_party/mlir/include/mlir/Dialect/AffineOps/AffineOps.td b/third_party/mlir/include/mlir/Dialect/AffineOps/AffineOps.td index 237692c04a7..4961ce8ee95 100644 --- a/third_party/mlir/include/mlir/Dialect/AffineOps/AffineOps.td +++ b/third_party/mlir/include/mlir/Dialect/AffineOps/AffineOps.td @@ -223,6 +223,11 @@ def AffineIfOp : Affine_Op<"if", [ImplicitAffineTerminator]> { IntegerSet getIntegerSet(); void setIntegerSet(IntegerSet newSet); + /// Sets the integer set with its operands. The size of 'operands' must not + /// exceed the current number of operands for this instance, as the operands + /// list of AffineIf is not resizable. + void setConditional(IntegerSet set, ArrayRef operands); + OpBuilder getThenBodyBuilder() { assert(!thenRegion().empty() && "Unexpected empty 'then' region."); Block &body = thenRegion().front(); @@ -234,6 +239,8 @@ def AffineIfOp : Affine_Op<"if", [ImplicitAffineTerminator]> { return OpBuilder(&body, std::prev(body.end())); } }]; + + let hasCanonicalizer = 1; } def AffineTerminatorOp : diff --git a/third_party/mlir/include/mlir/IR/IntegerSet.h b/third_party/mlir/include/mlir/IR/IntegerSet.h index b7662f095a5..e989f91bafd 100644 --- a/third_party/mlir/include/mlir/IR/IntegerSet.h +++ b/third_party/mlir/include/mlir/IR/IntegerSet.h @@ -72,12 +72,22 @@ public: /// Returns true if this is the canonical integer set. bool isEmptyIntegerSet() const; + /// This method substitutes any uses of dimensions and symbols (e.g. + /// dim#0 with dimReplacements[0]) in subexpressions and returns the modified + /// integer set. Because this can be used to eliminate dims and + /// symbols, the client needs to specify the number of dims and symbols in + /// the result. The returned map always has the same number of results. + IntegerSet replaceDimsAndSymbols(ArrayRef dimReplacements, + ArrayRef symReplacements, + unsigned numResultDims, + unsigned numResultSyms); + explicit operator bool() { return set; } bool operator==(IntegerSet other) const { return set == other.set; } unsigned getNumDims() const; unsigned getNumSymbols() const; - unsigned getNumOperands() const; + unsigned getNumInputs() const; unsigned getNumConstraints() const; unsigned getNumEqualities() const; unsigned getNumInequalities() const; @@ -96,6 +106,10 @@ public: MLIRContext *getContext() const; + /// Walk all of the AffineExpr's in this set's constraints. Each node in an + /// expression tree is visited in postorder. + void walkExprs(llvm::function_ref callback) const; + void print(raw_ostream &os) const; void dump() const; diff --git a/third_party/mlir/lib/Analysis/AffineStructures.cpp b/third_party/mlir/lib/Analysis/AffineStructures.cpp index f660fff7df6..2804ac68b4b 100644 --- a/third_party/mlir/lib/Analysis/AffineStructures.cpp +++ b/third_party/mlir/lib/Analysis/AffineStructures.cpp @@ -308,7 +308,7 @@ std::unique_ptr FlatAffineConstraints::clone() const { // Construct from an IntegerSet. FlatAffineConstraints::FlatAffineConstraints(IntegerSet set) - : numReservedCols(set.getNumOperands() + 1), + : numReservedCols(set.getNumInputs() + 1), numIds(set.getNumDims() + set.getNumSymbols()), numDims(set.getNumDims()), numSymbols(set.getNumSymbols()) { equalities.reserve(set.getNumEqualities() * numReservedCols); diff --git a/third_party/mlir/lib/Dialect/AffineOps/AffineOps.cpp b/third_party/mlir/lib/Dialect/AffineOps/AffineOps.cpp index c6abc05b966..2161ae0d164 100644 --- a/third_party/mlir/lib/Dialect/AffineOps/AffineOps.cpp +++ b/third_party/mlir/lib/Dialect/AffineOps/AffineOps.cpp @@ -620,26 +620,27 @@ AffineApplyOp mlir::makeComposedAffineApply(OpBuilder &b, Location loc, // A symbol may appear as a dim in affine.apply operations. This function // canonicalizes dims that are valid symbols into actual symbols. +template static void -canonicalizePromotedSymbols(AffineMap *map, +canonicalizePromotedSymbols(MapOrSet *mapOrSet, llvm::SmallVectorImpl *operands) { - if (!map || operands->empty()) + if (!mapOrSet || operands->empty()) return; - assert(map->getNumInputs() == operands->size() && - "map inputs must match number of operands"); + assert(mapOrSet->getNumInputs() == operands->size() && + "map/set inputs must match number of operands"); - auto *context = map->getContext(); + auto *context = mapOrSet->getContext(); SmallVector resultOperands; resultOperands.reserve(operands->size()); SmallVector remappedSymbols; remappedSymbols.reserve(operands->size()); unsigned nextDim = 0; unsigned nextSym = 0; - unsigned oldNumSyms = map->getNumSymbols(); - SmallVector dimRemapping(map->getNumDims()); - for (unsigned i = 0, e = map->getNumInputs(); i != e; ++i) { - if (i < map->getNumDims()) { + unsigned oldNumSyms = mapOrSet->getNumSymbols(); + SmallVector dimRemapping(mapOrSet->getNumDims()); + for (unsigned i = 0, e = mapOrSet->getNumInputs(); i != e; ++i) { + if (i < mapOrSet->getNumDims()) { if (isValidSymbol((*operands)[i])) { // This is a valid symbol that appears as a dim, canonicalize it. dimRemapping[i] = getAffineSymbolExpr(oldNumSyms + nextSym++, context); @@ -655,42 +656,49 @@ canonicalizePromotedSymbols(AffineMap *map, resultOperands.append(remappedSymbols.begin(), remappedSymbols.end()); *operands = resultOperands; - *map = map->replaceDimsAndSymbols(dimRemapping, {}, nextDim, - oldNumSyms + nextSym); + *mapOrSet = mapOrSet->replaceDimsAndSymbols(dimRemapping, {}, nextDim, + oldNumSyms + nextSym); - assert(map->getNumInputs() == operands->size() && - "map inputs must match number of operands"); + assert(mapOrSet->getNumInputs() == operands->size() && + "map/set inputs must match number of operands"); } -void mlir::canonicalizeMapAndOperands( - AffineMap *map, llvm::SmallVectorImpl *operands) { - if (!map || operands->empty()) +// Works for either an affine map or an integer set. +template +static void +canonicalizeMapOrSetAndOperands(MapOrSet *mapOrSet, + llvm::SmallVectorImpl *operands) { + static_assert(std::is_same::value || + std::is_same::value, + "Argument must be either of AffineMap or IntegerSet type"); + + if (!mapOrSet || operands->empty()) return; - assert(map->getNumInputs() == operands->size() && - "map inputs must match number of operands"); + assert(mapOrSet->getNumInputs() == operands->size() && + "map/set inputs must match number of operands"); - canonicalizePromotedSymbols(map, operands); + canonicalizePromotedSymbols(mapOrSet, operands); // Check to see what dims are used. - llvm::SmallBitVector usedDims(map->getNumDims()); - llvm::SmallBitVector usedSyms(map->getNumSymbols()); - map->walkExprs([&](AffineExpr expr) { + llvm::SmallBitVector usedDims(mapOrSet->getNumDims()); + llvm::SmallBitVector usedSyms(mapOrSet->getNumSymbols()); + mapOrSet->walkExprs([&](AffineExpr expr) { if (auto dimExpr = expr.dyn_cast()) usedDims[dimExpr.getPosition()] = true; else if (auto symExpr = expr.dyn_cast()) usedSyms[symExpr.getPosition()] = true; }); - auto *context = map->getContext(); + auto *context = mapOrSet->getContext(); SmallVector resultOperands; resultOperands.reserve(operands->size()); llvm::SmallDenseMap seenDims; - SmallVector dimRemapping(map->getNumDims()); + SmallVector dimRemapping(mapOrSet->getNumDims()); unsigned nextDim = 0; - for (unsigned i = 0, e = map->getNumDims(); i != e; ++i) { + for (unsigned i = 0, e = mapOrSet->getNumDims(); i != e; ++i) { if (usedDims[i]) { // Remap dim positions for duplicate operands. auto it = seenDims.find((*operands)[i]); @@ -704,37 +712,47 @@ void mlir::canonicalizeMapAndOperands( } } llvm::SmallDenseMap seenSymbols; - SmallVector symRemapping(map->getNumSymbols()); + SmallVector symRemapping(mapOrSet->getNumSymbols()); unsigned nextSym = 0; - for (unsigned i = 0, e = map->getNumSymbols(); i != e; ++i) { + for (unsigned i = 0, e = mapOrSet->getNumSymbols(); i != e; ++i) { if (!usedSyms[i]) continue; // Handle constant operands (only needed for symbolic operands since // constant operands in dimensional positions would have already been // promoted to symbolic positions above). IntegerAttr operandCst; - if (matchPattern((*operands)[i + map->getNumDims()], + if (matchPattern((*operands)[i + mapOrSet->getNumDims()], m_Constant(&operandCst))) { symRemapping[i] = getAffineConstantExpr(operandCst.getValue().getSExtValue(), context); continue; } // Remap symbol positions for duplicate operands. - auto it = seenSymbols.find((*operands)[i + map->getNumDims()]); + auto it = seenSymbols.find((*operands)[i + mapOrSet->getNumDims()]); if (it == seenSymbols.end()) { symRemapping[i] = getAffineSymbolExpr(nextSym++, context); - resultOperands.push_back((*operands)[i + map->getNumDims()]); - seenSymbols.insert( - std::make_pair((*operands)[i + map->getNumDims()], symRemapping[i])); + resultOperands.push_back((*operands)[i + mapOrSet->getNumDims()]); + seenSymbols.insert(std::make_pair((*operands)[i + mapOrSet->getNumDims()], + symRemapping[i])); } else { symRemapping[i] = it->second; } } - *map = - map->replaceDimsAndSymbols(dimRemapping, symRemapping, nextDim, nextSym); + *mapOrSet = mapOrSet->replaceDimsAndSymbols(dimRemapping, symRemapping, + nextDim, nextSym); *operands = resultOperands; } +void mlir::canonicalizeMapAndOperands( + AffineMap *map, llvm::SmallVectorImpl *operands) { + canonicalizeMapOrSetAndOperands(map, operands); +} + +void mlir::canonicalizeSetAndOperands( + IntegerSet *set, llvm::SmallVectorImpl *operands) { + canonicalizeMapOrSetAndOperands(set, operands); +} + namespace { /// Simplify AffineApply operations. /// @@ -1540,7 +1558,7 @@ static LogicalResult verify(AffineIfOp op) { // Verify that there are enough operands for the condition. IntegerSet condition = conditionAttr.getValue(); - if (op.getNumOperands() != condition.getNumOperands()) + if (op.getNumOperands() != condition.getNumInputs()) return op.emitOpError( "operand count and condition integer set dimension and " "symbol count must match"); @@ -1639,6 +1657,44 @@ void AffineIfOp::setIntegerSet(IntegerSet newSet) { setAttr(getConditionAttrName(), IntegerSetAttr::get(newSet)); } +void AffineIfOp::setConditional(IntegerSet set, ArrayRef operands) { + setIntegerSet(set); + getOperation()->setOperands(operands); +} + +namespace { +// This is a pattern to canonicalize an affine if op's conditional (integer +// set + operands). +struct AffineIfOpCanonicalizer : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + PatternMatchResult matchAndRewrite(AffineIfOp ifOp, + PatternRewriter &rewriter) const override { + auto set = ifOp.getIntegerSet(); + SmallVector operands(ifOp.getOperands()); + + canonicalizeSetAndOperands(&set, &operands); + + // Any canonicalization change always leads to either a reduction in the + // number of operands or a change in the number of symbolic operands + // (promotion of dims to symbols). + if (operands.size() < ifOp.getIntegerSet().getNumInputs() || + set.getNumSymbols() > ifOp.getIntegerSet().getNumSymbols()) { + ifOp.setConditional(set, operands); + rewriter.updatedRootInPlace(ifOp); + return matchSuccess(); + } + + return matchFailure(); + } +}; +} // end anonymous namespace + +void AffineIfOp::getCanonicalizationPatterns(OwningRewritePatternList &results, + MLIRContext *context) { + results.insert(context); +} + //===----------------------------------------------------------------------===// // AffineLoadOp //===----------------------------------------------------------------------===// diff --git a/third_party/mlir/lib/IR/IntegerSet.cpp b/third_party/mlir/lib/IR/IntegerSet.cpp index 74a1297dcdd..139ca504b58 100644 --- a/third_party/mlir/lib/IR/IntegerSet.cpp +++ b/third_party/mlir/lib/IR/IntegerSet.cpp @@ -24,7 +24,7 @@ using namespace mlir::detail; unsigned IntegerSet::getNumDims() const { return set->dimCount; } unsigned IntegerSet::getNumSymbols() const { return set->symbolCount; } -unsigned IntegerSet::getNumOperands() const { +unsigned IntegerSet::getNumInputs() const { return set->dimCount + set->symbolCount; } @@ -70,3 +70,23 @@ bool IntegerSet::isEq(unsigned idx) const { return getEqFlags()[idx]; } MLIRContext *IntegerSet::getContext() const { return getConstraint(0).getContext(); } + +/// Walk all of the AffineExpr's in this set. Each node in an expression +/// tree is visited in postorder. +void IntegerSet::walkExprs( + llvm::function_ref callback) const { + for (auto expr : getConstraints()) + expr.walk(callback); +} + +IntegerSet IntegerSet::replaceDimsAndSymbols( + ArrayRef dimReplacements, ArrayRef symReplacements, + unsigned numResultDims, unsigned numResultSyms) { + SmallVector constraints; + constraints.reserve(getNumConstraints()); + for (auto cst : getConstraints()) + constraints.push_back( + cst.replaceDimsAndSymbols(dimReplacements, symReplacements)); + + return get(numResultDims, numResultSyms, constraints, getEqFlags()); +}