Integer set + operands / affine if op canonicalization
- turn canonicalizeMapAndOperands into a template that works on both sets and maps, and use it to introduce a utility to canonicalize an affine integer set and its operands - add pattern to canonicalize affine if op's. - rename IntegerSet::getNumOperands -> IntegerSet::getNumInputs to be consistent with AffineMap - add missing accessors for IntegerSet Doesn't need extensive testing since canonicalizeSetAndOperands just reuses canonicalizeMapAndOperands' logic, and the latter is tested on affine.apply map + operands; the new method works the same way on an integer set + operands of an affine if op for example. Signed-off-by: Uday Bondhugula <uday@polymagelabs.com> Closes #112 COPYBARA_INTEGRATE_REVIEW=https://github.com/tensorflow/mlir/pull/112 from bondhugula:set-canonicalize eff72f23250b96fa7d9f5caff3877440f5de2cec PiperOrigin-RevId: 267532876
This commit is contained in:
parent
cf27972e28
commit
2ad042991c
third_party/mlir
include/mlir
lib
@ -522,6 +522,10 @@ bool isValidSymbol(Value *value);
|
||||
/// 2. drop unused dims and symbols from map
|
||||
void canonicalizeMapAndOperands(AffineMap *map,
|
||||
llvm::SmallVectorImpl<Value *> *operands);
|
||||
/// Canonicalizes an integer set the same way canonicalizeMapAndOperands does
|
||||
/// for affine maps.
|
||||
void canonicalizeSetAndOperands(IntegerSet *set,
|
||||
llvm::SmallVectorImpl<Value *> *operands);
|
||||
|
||||
/// Returns a composed AffineApplyOp by composing `map` and `operands` with
|
||||
/// other AffineApplyOps supplying those operands. The operands of the resulting
|
||||
|
@ -223,6 +223,11 @@ def AffineIfOp : Affine_Op<"if", [ImplicitAffineTerminator]> {
|
||||
IntegerSet getIntegerSet();
|
||||
void setIntegerSet(IntegerSet newSet);
|
||||
|
||||
/// Sets the integer set with its operands. The size of 'operands' must not
|
||||
/// exceed the current number of operands for this instance, as the operands
|
||||
/// list of AffineIf is not resizable.
|
||||
void setConditional(IntegerSet set, ArrayRef<Value *> operands);
|
||||
|
||||
OpBuilder getThenBodyBuilder() {
|
||||
assert(!thenRegion().empty() && "Unexpected empty 'then' region.");
|
||||
Block &body = thenRegion().front();
|
||||
@ -234,6 +239,8 @@ def AffineIfOp : Affine_Op<"if", [ImplicitAffineTerminator]> {
|
||||
return OpBuilder(&body, std::prev(body.end()));
|
||||
}
|
||||
}];
|
||||
|
||||
let hasCanonicalizer = 1;
|
||||
}
|
||||
|
||||
def AffineTerminatorOp :
|
||||
|
16
third_party/mlir/include/mlir/IR/IntegerSet.h
vendored
16
third_party/mlir/include/mlir/IR/IntegerSet.h
vendored
@ -72,12 +72,22 @@ public:
|
||||
/// Returns true if this is the canonical integer set.
|
||||
bool isEmptyIntegerSet() const;
|
||||
|
||||
/// This method substitutes any uses of dimensions and symbols (e.g.
|
||||
/// dim#0 with dimReplacements[0]) in subexpressions and returns the modified
|
||||
/// integer set. Because this can be used to eliminate dims and
|
||||
/// symbols, the client needs to specify the number of dims and symbols in
|
||||
/// the result. The returned map always has the same number of results.
|
||||
IntegerSet replaceDimsAndSymbols(ArrayRef<AffineExpr> dimReplacements,
|
||||
ArrayRef<AffineExpr> symReplacements,
|
||||
unsigned numResultDims,
|
||||
unsigned numResultSyms);
|
||||
|
||||
explicit operator bool() { return set; }
|
||||
bool operator==(IntegerSet other) const { return set == other.set; }
|
||||
|
||||
unsigned getNumDims() const;
|
||||
unsigned getNumSymbols() const;
|
||||
unsigned getNumOperands() const;
|
||||
unsigned getNumInputs() const;
|
||||
unsigned getNumConstraints() const;
|
||||
unsigned getNumEqualities() const;
|
||||
unsigned getNumInequalities() const;
|
||||
@ -96,6 +106,10 @@ public:
|
||||
|
||||
MLIRContext *getContext() const;
|
||||
|
||||
/// Walk all of the AffineExpr's in this set's constraints. Each node in an
|
||||
/// expression tree is visited in postorder.
|
||||
void walkExprs(llvm::function_ref<void(AffineExpr)> callback) const;
|
||||
|
||||
void print(raw_ostream &os) const;
|
||||
void dump() const;
|
||||
|
||||
|
@ -308,7 +308,7 @@ std::unique_ptr<FlatAffineConstraints> FlatAffineConstraints::clone() const {
|
||||
|
||||
// Construct from an IntegerSet.
|
||||
FlatAffineConstraints::FlatAffineConstraints(IntegerSet set)
|
||||
: numReservedCols(set.getNumOperands() + 1),
|
||||
: numReservedCols(set.getNumInputs() + 1),
|
||||
numIds(set.getNumDims() + set.getNumSymbols()), numDims(set.getNumDims()),
|
||||
numSymbols(set.getNumSymbols()) {
|
||||
equalities.reserve(set.getNumEqualities() * numReservedCols);
|
||||
|
126
third_party/mlir/lib/Dialect/AffineOps/AffineOps.cpp
vendored
126
third_party/mlir/lib/Dialect/AffineOps/AffineOps.cpp
vendored
@ -620,26 +620,27 @@ AffineApplyOp mlir::makeComposedAffineApply(OpBuilder &b, Location loc,
|
||||
|
||||
// A symbol may appear as a dim in affine.apply operations. This function
|
||||
// canonicalizes dims that are valid symbols into actual symbols.
|
||||
template <class MapOrSet>
|
||||
static void
|
||||
canonicalizePromotedSymbols(AffineMap *map,
|
||||
canonicalizePromotedSymbols(MapOrSet *mapOrSet,
|
||||
llvm::SmallVectorImpl<Value *> *operands) {
|
||||
if (!map || operands->empty())
|
||||
if (!mapOrSet || operands->empty())
|
||||
return;
|
||||
|
||||
assert(map->getNumInputs() == operands->size() &&
|
||||
"map inputs must match number of operands");
|
||||
assert(mapOrSet->getNumInputs() == operands->size() &&
|
||||
"map/set inputs must match number of operands");
|
||||
|
||||
auto *context = map->getContext();
|
||||
auto *context = mapOrSet->getContext();
|
||||
SmallVector<Value *, 8> resultOperands;
|
||||
resultOperands.reserve(operands->size());
|
||||
SmallVector<Value *, 8> remappedSymbols;
|
||||
remappedSymbols.reserve(operands->size());
|
||||
unsigned nextDim = 0;
|
||||
unsigned nextSym = 0;
|
||||
unsigned oldNumSyms = map->getNumSymbols();
|
||||
SmallVector<AffineExpr, 8> dimRemapping(map->getNumDims());
|
||||
for (unsigned i = 0, e = map->getNumInputs(); i != e; ++i) {
|
||||
if (i < map->getNumDims()) {
|
||||
unsigned oldNumSyms = mapOrSet->getNumSymbols();
|
||||
SmallVector<AffineExpr, 8> dimRemapping(mapOrSet->getNumDims());
|
||||
for (unsigned i = 0, e = mapOrSet->getNumInputs(); i != e; ++i) {
|
||||
if (i < mapOrSet->getNumDims()) {
|
||||
if (isValidSymbol((*operands)[i])) {
|
||||
// This is a valid symbol that appears as a dim, canonicalize it.
|
||||
dimRemapping[i] = getAffineSymbolExpr(oldNumSyms + nextSym++, context);
|
||||
@ -655,42 +656,49 @@ canonicalizePromotedSymbols(AffineMap *map,
|
||||
|
||||
resultOperands.append(remappedSymbols.begin(), remappedSymbols.end());
|
||||
*operands = resultOperands;
|
||||
*map = map->replaceDimsAndSymbols(dimRemapping, {}, nextDim,
|
||||
oldNumSyms + nextSym);
|
||||
*mapOrSet = mapOrSet->replaceDimsAndSymbols(dimRemapping, {}, nextDim,
|
||||
oldNumSyms + nextSym);
|
||||
|
||||
assert(map->getNumInputs() == operands->size() &&
|
||||
"map inputs must match number of operands");
|
||||
assert(mapOrSet->getNumInputs() == operands->size() &&
|
||||
"map/set inputs must match number of operands");
|
||||
}
|
||||
|
||||
void mlir::canonicalizeMapAndOperands(
|
||||
AffineMap *map, llvm::SmallVectorImpl<Value *> *operands) {
|
||||
if (!map || operands->empty())
|
||||
// Works for either an affine map or an integer set.
|
||||
template <class MapOrSet>
|
||||
static void
|
||||
canonicalizeMapOrSetAndOperands(MapOrSet *mapOrSet,
|
||||
llvm::SmallVectorImpl<Value *> *operands) {
|
||||
static_assert(std::is_same<MapOrSet, AffineMap>::value ||
|
||||
std::is_same<MapOrSet, IntegerSet>::value,
|
||||
"Argument must be either of AffineMap or IntegerSet type");
|
||||
|
||||
if (!mapOrSet || operands->empty())
|
||||
return;
|
||||
|
||||
assert(map->getNumInputs() == operands->size() &&
|
||||
"map inputs must match number of operands");
|
||||
assert(mapOrSet->getNumInputs() == operands->size() &&
|
||||
"map/set inputs must match number of operands");
|
||||
|
||||
canonicalizePromotedSymbols(map, operands);
|
||||
canonicalizePromotedSymbols<MapOrSet>(mapOrSet, operands);
|
||||
|
||||
// Check to see what dims are used.
|
||||
llvm::SmallBitVector usedDims(map->getNumDims());
|
||||
llvm::SmallBitVector usedSyms(map->getNumSymbols());
|
||||
map->walkExprs([&](AffineExpr expr) {
|
||||
llvm::SmallBitVector usedDims(mapOrSet->getNumDims());
|
||||
llvm::SmallBitVector usedSyms(mapOrSet->getNumSymbols());
|
||||
mapOrSet->walkExprs([&](AffineExpr expr) {
|
||||
if (auto dimExpr = expr.dyn_cast<AffineDimExpr>())
|
||||
usedDims[dimExpr.getPosition()] = true;
|
||||
else if (auto symExpr = expr.dyn_cast<AffineSymbolExpr>())
|
||||
usedSyms[symExpr.getPosition()] = true;
|
||||
});
|
||||
|
||||
auto *context = map->getContext();
|
||||
auto *context = mapOrSet->getContext();
|
||||
|
||||
SmallVector<Value *, 8> resultOperands;
|
||||
resultOperands.reserve(operands->size());
|
||||
|
||||
llvm::SmallDenseMap<Value *, AffineExpr, 8> seenDims;
|
||||
SmallVector<AffineExpr, 8> dimRemapping(map->getNumDims());
|
||||
SmallVector<AffineExpr, 8> dimRemapping(mapOrSet->getNumDims());
|
||||
unsigned nextDim = 0;
|
||||
for (unsigned i = 0, e = map->getNumDims(); i != e; ++i) {
|
||||
for (unsigned i = 0, e = mapOrSet->getNumDims(); i != e; ++i) {
|
||||
if (usedDims[i]) {
|
||||
// Remap dim positions for duplicate operands.
|
||||
auto it = seenDims.find((*operands)[i]);
|
||||
@ -704,37 +712,47 @@ void mlir::canonicalizeMapAndOperands(
|
||||
}
|
||||
}
|
||||
llvm::SmallDenseMap<Value *, AffineExpr, 8> seenSymbols;
|
||||
SmallVector<AffineExpr, 8> symRemapping(map->getNumSymbols());
|
||||
SmallVector<AffineExpr, 8> symRemapping(mapOrSet->getNumSymbols());
|
||||
unsigned nextSym = 0;
|
||||
for (unsigned i = 0, e = map->getNumSymbols(); i != e; ++i) {
|
||||
for (unsigned i = 0, e = mapOrSet->getNumSymbols(); i != e; ++i) {
|
||||
if (!usedSyms[i])
|
||||
continue;
|
||||
// Handle constant operands (only needed for symbolic operands since
|
||||
// constant operands in dimensional positions would have already been
|
||||
// promoted to symbolic positions above).
|
||||
IntegerAttr operandCst;
|
||||
if (matchPattern((*operands)[i + map->getNumDims()],
|
||||
if (matchPattern((*operands)[i + mapOrSet->getNumDims()],
|
||||
m_Constant(&operandCst))) {
|
||||
symRemapping[i] =
|
||||
getAffineConstantExpr(operandCst.getValue().getSExtValue(), context);
|
||||
continue;
|
||||
}
|
||||
// Remap symbol positions for duplicate operands.
|
||||
auto it = seenSymbols.find((*operands)[i + map->getNumDims()]);
|
||||
auto it = seenSymbols.find((*operands)[i + mapOrSet->getNumDims()]);
|
||||
if (it == seenSymbols.end()) {
|
||||
symRemapping[i] = getAffineSymbolExpr(nextSym++, context);
|
||||
resultOperands.push_back((*operands)[i + map->getNumDims()]);
|
||||
seenSymbols.insert(
|
||||
std::make_pair((*operands)[i + map->getNumDims()], symRemapping[i]));
|
||||
resultOperands.push_back((*operands)[i + mapOrSet->getNumDims()]);
|
||||
seenSymbols.insert(std::make_pair((*operands)[i + mapOrSet->getNumDims()],
|
||||
symRemapping[i]));
|
||||
} else {
|
||||
symRemapping[i] = it->second;
|
||||
}
|
||||
}
|
||||
*map =
|
||||
map->replaceDimsAndSymbols(dimRemapping, symRemapping, nextDim, nextSym);
|
||||
*mapOrSet = mapOrSet->replaceDimsAndSymbols(dimRemapping, symRemapping,
|
||||
nextDim, nextSym);
|
||||
*operands = resultOperands;
|
||||
}
|
||||
|
||||
void mlir::canonicalizeMapAndOperands(
|
||||
AffineMap *map, llvm::SmallVectorImpl<Value *> *operands) {
|
||||
canonicalizeMapOrSetAndOperands<AffineMap>(map, operands);
|
||||
}
|
||||
|
||||
void mlir::canonicalizeSetAndOperands(
|
||||
IntegerSet *set, llvm::SmallVectorImpl<Value *> *operands) {
|
||||
canonicalizeMapOrSetAndOperands<IntegerSet>(set, operands);
|
||||
}
|
||||
|
||||
namespace {
|
||||
/// Simplify AffineApply operations.
|
||||
///
|
||||
@ -1540,7 +1558,7 @@ static LogicalResult verify(AffineIfOp op) {
|
||||
|
||||
// Verify that there are enough operands for the condition.
|
||||
IntegerSet condition = conditionAttr.getValue();
|
||||
if (op.getNumOperands() != condition.getNumOperands())
|
||||
if (op.getNumOperands() != condition.getNumInputs())
|
||||
return op.emitOpError(
|
||||
"operand count and condition integer set dimension and "
|
||||
"symbol count must match");
|
||||
@ -1639,6 +1657,44 @@ void AffineIfOp::setIntegerSet(IntegerSet newSet) {
|
||||
setAttr(getConditionAttrName(), IntegerSetAttr::get(newSet));
|
||||
}
|
||||
|
||||
void AffineIfOp::setConditional(IntegerSet set, ArrayRef<Value *> operands) {
|
||||
setIntegerSet(set);
|
||||
getOperation()->setOperands(operands);
|
||||
}
|
||||
|
||||
namespace {
|
||||
// This is a pattern to canonicalize an affine if op's conditional (integer
|
||||
// set + operands).
|
||||
struct AffineIfOpCanonicalizer : public OpRewritePattern<AffineIfOp> {
|
||||
using OpRewritePattern<AffineIfOp>::OpRewritePattern;
|
||||
|
||||
PatternMatchResult matchAndRewrite(AffineIfOp ifOp,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto set = ifOp.getIntegerSet();
|
||||
SmallVector<Value *, 4> operands(ifOp.getOperands());
|
||||
|
||||
canonicalizeSetAndOperands(&set, &operands);
|
||||
|
||||
// Any canonicalization change always leads to either a reduction in the
|
||||
// number of operands or a change in the number of symbolic operands
|
||||
// (promotion of dims to symbols).
|
||||
if (operands.size() < ifOp.getIntegerSet().getNumInputs() ||
|
||||
set.getNumSymbols() > ifOp.getIntegerSet().getNumSymbols()) {
|
||||
ifOp.setConditional(set, operands);
|
||||
rewriter.updatedRootInPlace(ifOp);
|
||||
return matchSuccess();
|
||||
}
|
||||
|
||||
return matchFailure();
|
||||
}
|
||||
};
|
||||
} // end anonymous namespace
|
||||
|
||||
void AffineIfOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
||||
MLIRContext *context) {
|
||||
results.insert<AffineIfOpCanonicalizer>(context);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AffineLoadOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
22
third_party/mlir/lib/IR/IntegerSet.cpp
vendored
22
third_party/mlir/lib/IR/IntegerSet.cpp
vendored
@ -24,7 +24,7 @@ using namespace mlir::detail;
|
||||
|
||||
unsigned IntegerSet::getNumDims() const { return set->dimCount; }
|
||||
unsigned IntegerSet::getNumSymbols() const { return set->symbolCount; }
|
||||
unsigned IntegerSet::getNumOperands() const {
|
||||
unsigned IntegerSet::getNumInputs() const {
|
||||
return set->dimCount + set->symbolCount;
|
||||
}
|
||||
|
||||
@ -70,3 +70,23 @@ bool IntegerSet::isEq(unsigned idx) const { return getEqFlags()[idx]; }
|
||||
MLIRContext *IntegerSet::getContext() const {
|
||||
return getConstraint(0).getContext();
|
||||
}
|
||||
|
||||
/// Walk all of the AffineExpr's in this set. Each node in an expression
|
||||
/// tree is visited in postorder.
|
||||
void IntegerSet::walkExprs(
|
||||
llvm::function_ref<void(AffineExpr)> callback) const {
|
||||
for (auto expr : getConstraints())
|
||||
expr.walk(callback);
|
||||
}
|
||||
|
||||
IntegerSet IntegerSet::replaceDimsAndSymbols(
|
||||
ArrayRef<AffineExpr> dimReplacements, ArrayRef<AffineExpr> symReplacements,
|
||||
unsigned numResultDims, unsigned numResultSyms) {
|
||||
SmallVector<AffineExpr, 8> constraints;
|
||||
constraints.reserve(getNumConstraints());
|
||||
for (auto cst : getConstraints())
|
||||
constraints.push_back(
|
||||
cst.replaceDimsAndSymbols(dimReplacements, symReplacements));
|
||||
|
||||
return get(numResultDims, numResultSyms, constraints, getEqFlags());
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user