diff --git a/third_party/mlir/include/mlir/Dialect/AffineOps/AffineOps.h b/third_party/mlir/include/mlir/Dialect/AffineOps/AffineOps.h
index a6af20eca0b..03b945c0a5f 100644
--- a/third_party/mlir/include/mlir/Dialect/AffineOps/AffineOps.h
+++ b/third_party/mlir/include/mlir/Dialect/AffineOps/AffineOps.h
@@ -522,6 +522,10 @@ bool isValidSymbol(Value *value);
 /// 2. drop unused dims and symbols from map
 void canonicalizeMapAndOperands(AffineMap *map,
                                 llvm::SmallVectorImpl<Value *> *operands);
+/// Canonicalizes an integer set the same way canonicalizeMapAndOperands does
+/// for affine maps.
+void canonicalizeSetAndOperands(IntegerSet *set,
+                                llvm::SmallVectorImpl<Value *> *operands);
 
 /// Returns a composed AffineApplyOp by composing `map` and `operands` with
 /// other AffineApplyOps supplying those operands. The operands of the resulting
diff --git a/third_party/mlir/include/mlir/Dialect/AffineOps/AffineOps.td b/third_party/mlir/include/mlir/Dialect/AffineOps/AffineOps.td
index 237692c04a7..4961ce8ee95 100644
--- a/third_party/mlir/include/mlir/Dialect/AffineOps/AffineOps.td
+++ b/third_party/mlir/include/mlir/Dialect/AffineOps/AffineOps.td
@@ -223,6 +223,11 @@ def AffineIfOp : Affine_Op<"if", [ImplicitAffineTerminator]> {
     IntegerSet getIntegerSet();
     void setIntegerSet(IntegerSet newSet);
 
+    /// Sets the integer set with its operands. The size of 'operands' must not
+    /// exceed the current number of operands for this instance, as the operands
+    /// list of AffineIf is not resizable.
+    void setConditional(IntegerSet set, ArrayRef<Value *> operands);
+
     OpBuilder getThenBodyBuilder() {
       assert(!thenRegion().empty() && "Unexpected empty 'then' region.");
       Block &body = thenRegion().front();
@@ -234,6 +239,8 @@ def AffineIfOp : Affine_Op<"if", [ImplicitAffineTerminator]> {
       return OpBuilder(&body, std::prev(body.end()));
     }
   }];
+
+  let hasCanonicalizer = 1;
 }
 
 def AffineTerminatorOp :
diff --git a/third_party/mlir/include/mlir/IR/IntegerSet.h b/third_party/mlir/include/mlir/IR/IntegerSet.h
index b7662f095a5..e989f91bafd 100644
--- a/third_party/mlir/include/mlir/IR/IntegerSet.h
+++ b/third_party/mlir/include/mlir/IR/IntegerSet.h
@@ -72,12 +72,22 @@ public:
   /// Returns true if this is the canonical integer set.
   bool isEmptyIntegerSet() const;
 
+  /// This method substitutes any uses of dimensions and symbols (e.g.
+  /// dim#0 with dimReplacements[0]) in subexpressions and returns the modified
+  /// integer set.  Because this can be used to eliminate dims and
+  /// symbols, the client needs to specify the number of dims and symbols in
+  /// the result.  The returned map always has the same number of results.
+  IntegerSet replaceDimsAndSymbols(ArrayRef<AffineExpr> dimReplacements,
+                                   ArrayRef<AffineExpr> symReplacements,
+                                   unsigned numResultDims,
+                                   unsigned numResultSyms);
+
   explicit operator bool() { return set; }
   bool operator==(IntegerSet other) const { return set == other.set; }
 
   unsigned getNumDims() const;
   unsigned getNumSymbols() const;
-  unsigned getNumOperands() const;
+  unsigned getNumInputs() const;
   unsigned getNumConstraints() const;
   unsigned getNumEqualities() const;
   unsigned getNumInequalities() const;
@@ -96,6 +106,10 @@ public:
 
   MLIRContext *getContext() const;
 
+  /// Walk all of the AffineExpr's in this set's constraints. Each node in an
+  /// expression tree is visited in postorder.
+  void walkExprs(llvm::function_ref<void(AffineExpr)> callback) const;
+
   void print(raw_ostream &os) const;
   void dump() const;
 
diff --git a/third_party/mlir/lib/Analysis/AffineStructures.cpp b/third_party/mlir/lib/Analysis/AffineStructures.cpp
index f660fff7df6..2804ac68b4b 100644
--- a/third_party/mlir/lib/Analysis/AffineStructures.cpp
+++ b/third_party/mlir/lib/Analysis/AffineStructures.cpp
@@ -308,7 +308,7 @@ std::unique_ptr<FlatAffineConstraints> FlatAffineConstraints::clone() const {
 
 // Construct from an IntegerSet.
 FlatAffineConstraints::FlatAffineConstraints(IntegerSet set)
-    : numReservedCols(set.getNumOperands() + 1),
+    : numReservedCols(set.getNumInputs() + 1),
       numIds(set.getNumDims() + set.getNumSymbols()), numDims(set.getNumDims()),
       numSymbols(set.getNumSymbols()) {
   equalities.reserve(set.getNumEqualities() * numReservedCols);
diff --git a/third_party/mlir/lib/Dialect/AffineOps/AffineOps.cpp b/third_party/mlir/lib/Dialect/AffineOps/AffineOps.cpp
index c6abc05b966..2161ae0d164 100644
--- a/third_party/mlir/lib/Dialect/AffineOps/AffineOps.cpp
+++ b/third_party/mlir/lib/Dialect/AffineOps/AffineOps.cpp
@@ -620,26 +620,27 @@ AffineApplyOp mlir::makeComposedAffineApply(OpBuilder &b, Location loc,
 
 // A symbol may appear as a dim in affine.apply operations. This function
 // canonicalizes dims that are valid symbols into actual symbols.
+template <class MapOrSet>
 static void
-canonicalizePromotedSymbols(AffineMap *map,
+canonicalizePromotedSymbols(MapOrSet *mapOrSet,
                             llvm::SmallVectorImpl<Value *> *operands) {
-  if (!map || operands->empty())
+  if (!mapOrSet || operands->empty())
     return;
 
-  assert(map->getNumInputs() == operands->size() &&
-         "map inputs must match number of operands");
+  assert(mapOrSet->getNumInputs() == operands->size() &&
+         "map/set inputs must match number of operands");
 
-  auto *context = map->getContext();
+  auto *context = mapOrSet->getContext();
   SmallVector<Value *, 8> resultOperands;
   resultOperands.reserve(operands->size());
   SmallVector<Value *, 8> remappedSymbols;
   remappedSymbols.reserve(operands->size());
   unsigned nextDim = 0;
   unsigned nextSym = 0;
-  unsigned oldNumSyms = map->getNumSymbols();
-  SmallVector<AffineExpr, 8> dimRemapping(map->getNumDims());
-  for (unsigned i = 0, e = map->getNumInputs(); i != e; ++i) {
-    if (i < map->getNumDims()) {
+  unsigned oldNumSyms = mapOrSet->getNumSymbols();
+  SmallVector<AffineExpr, 8> dimRemapping(mapOrSet->getNumDims());
+  for (unsigned i = 0, e = mapOrSet->getNumInputs(); i != e; ++i) {
+    if (i < mapOrSet->getNumDims()) {
       if (isValidSymbol((*operands)[i])) {
         // This is a valid symbol that appears as a dim, canonicalize it.
         dimRemapping[i] = getAffineSymbolExpr(oldNumSyms + nextSym++, context);
@@ -655,42 +656,49 @@ canonicalizePromotedSymbols(AffineMap *map,
 
   resultOperands.append(remappedSymbols.begin(), remappedSymbols.end());
   *operands = resultOperands;
-  *map = map->replaceDimsAndSymbols(dimRemapping, {}, nextDim,
-                                    oldNumSyms + nextSym);
+  *mapOrSet = mapOrSet->replaceDimsAndSymbols(dimRemapping, {}, nextDim,
+                                              oldNumSyms + nextSym);
 
-  assert(map->getNumInputs() == operands->size() &&
-         "map inputs must match number of operands");
+  assert(mapOrSet->getNumInputs() == operands->size() &&
+         "map/set inputs must match number of operands");
 }
 
-void mlir::canonicalizeMapAndOperands(
-    AffineMap *map, llvm::SmallVectorImpl<Value *> *operands) {
-  if (!map || operands->empty())
+// Works for either an affine map or an integer set.
+template <class MapOrSet>
+static void
+canonicalizeMapOrSetAndOperands(MapOrSet *mapOrSet,
+                                llvm::SmallVectorImpl<Value *> *operands) {
+  static_assert(std::is_same<MapOrSet, AffineMap>::value ||
+                    std::is_same<MapOrSet, IntegerSet>::value,
+                "Argument must be either of AffineMap or IntegerSet type");
+
+  if (!mapOrSet || operands->empty())
     return;
 
-  assert(map->getNumInputs() == operands->size() &&
-         "map inputs must match number of operands");
+  assert(mapOrSet->getNumInputs() == operands->size() &&
+         "map/set inputs must match number of operands");
 
-  canonicalizePromotedSymbols(map, operands);
+  canonicalizePromotedSymbols<MapOrSet>(mapOrSet, operands);
 
   // Check to see what dims are used.
-  llvm::SmallBitVector usedDims(map->getNumDims());
-  llvm::SmallBitVector usedSyms(map->getNumSymbols());
-  map->walkExprs([&](AffineExpr expr) {
+  llvm::SmallBitVector usedDims(mapOrSet->getNumDims());
+  llvm::SmallBitVector usedSyms(mapOrSet->getNumSymbols());
+  mapOrSet->walkExprs([&](AffineExpr expr) {
     if (auto dimExpr = expr.dyn_cast<AffineDimExpr>())
       usedDims[dimExpr.getPosition()] = true;
     else if (auto symExpr = expr.dyn_cast<AffineSymbolExpr>())
       usedSyms[symExpr.getPosition()] = true;
   });
 
-  auto *context = map->getContext();
+  auto *context = mapOrSet->getContext();
 
   SmallVector<Value *, 8> resultOperands;
   resultOperands.reserve(operands->size());
 
   llvm::SmallDenseMap<Value *, AffineExpr, 8> seenDims;
-  SmallVector<AffineExpr, 8> dimRemapping(map->getNumDims());
+  SmallVector<AffineExpr, 8> dimRemapping(mapOrSet->getNumDims());
   unsigned nextDim = 0;
-  for (unsigned i = 0, e = map->getNumDims(); i != e; ++i) {
+  for (unsigned i = 0, e = mapOrSet->getNumDims(); i != e; ++i) {
     if (usedDims[i]) {
       // Remap dim positions for duplicate operands.
       auto it = seenDims.find((*operands)[i]);
@@ -704,37 +712,47 @@ void mlir::canonicalizeMapAndOperands(
     }
   }
   llvm::SmallDenseMap<Value *, AffineExpr, 8> seenSymbols;
-  SmallVector<AffineExpr, 8> symRemapping(map->getNumSymbols());
+  SmallVector<AffineExpr, 8> symRemapping(mapOrSet->getNumSymbols());
   unsigned nextSym = 0;
-  for (unsigned i = 0, e = map->getNumSymbols(); i != e; ++i) {
+  for (unsigned i = 0, e = mapOrSet->getNumSymbols(); i != e; ++i) {
     if (!usedSyms[i])
       continue;
     // Handle constant operands (only needed for symbolic operands since
     // constant operands in dimensional positions would have already been
     // promoted to symbolic positions above).
     IntegerAttr operandCst;
-    if (matchPattern((*operands)[i + map->getNumDims()],
+    if (matchPattern((*operands)[i + mapOrSet->getNumDims()],
                      m_Constant(&operandCst))) {
       symRemapping[i] =
           getAffineConstantExpr(operandCst.getValue().getSExtValue(), context);
       continue;
     }
     // Remap symbol positions for duplicate operands.
-    auto it = seenSymbols.find((*operands)[i + map->getNumDims()]);
+    auto it = seenSymbols.find((*operands)[i + mapOrSet->getNumDims()]);
     if (it == seenSymbols.end()) {
       symRemapping[i] = getAffineSymbolExpr(nextSym++, context);
-      resultOperands.push_back((*operands)[i + map->getNumDims()]);
-      seenSymbols.insert(
-          std::make_pair((*operands)[i + map->getNumDims()], symRemapping[i]));
+      resultOperands.push_back((*operands)[i + mapOrSet->getNumDims()]);
+      seenSymbols.insert(std::make_pair((*operands)[i + mapOrSet->getNumDims()],
+                                        symRemapping[i]));
     } else {
       symRemapping[i] = it->second;
     }
   }
-  *map =
-      map->replaceDimsAndSymbols(dimRemapping, symRemapping, nextDim, nextSym);
+  *mapOrSet = mapOrSet->replaceDimsAndSymbols(dimRemapping, symRemapping,
+                                              nextDim, nextSym);
   *operands = resultOperands;
 }
 
+void mlir::canonicalizeMapAndOperands(
+    AffineMap *map, llvm::SmallVectorImpl<Value *> *operands) {
+  canonicalizeMapOrSetAndOperands<AffineMap>(map, operands);
+}
+
+void mlir::canonicalizeSetAndOperands(
+    IntegerSet *set, llvm::SmallVectorImpl<Value *> *operands) {
+  canonicalizeMapOrSetAndOperands<IntegerSet>(set, operands);
+}
+
 namespace {
 /// Simplify AffineApply operations.
 ///
@@ -1540,7 +1558,7 @@ static LogicalResult verify(AffineIfOp op) {
 
   // Verify that there are enough operands for the condition.
   IntegerSet condition = conditionAttr.getValue();
-  if (op.getNumOperands() != condition.getNumOperands())
+  if (op.getNumOperands() != condition.getNumInputs())
     return op.emitOpError(
         "operand count and condition integer set dimension and "
         "symbol count must match");
@@ -1639,6 +1657,44 @@ void AffineIfOp::setIntegerSet(IntegerSet newSet) {
   setAttr(getConditionAttrName(), IntegerSetAttr::get(newSet));
 }
 
+void AffineIfOp::setConditional(IntegerSet set, ArrayRef<Value *> operands) {
+  setIntegerSet(set);
+  getOperation()->setOperands(operands);
+}
+
+namespace {
+// This is a pattern to canonicalize an affine if op's conditional (integer
+// set + operands).
+struct AffineIfOpCanonicalizer : public OpRewritePattern<AffineIfOp> {
+  using OpRewritePattern<AffineIfOp>::OpRewritePattern;
+
+  PatternMatchResult matchAndRewrite(AffineIfOp ifOp,
+                                     PatternRewriter &rewriter) const override {
+    auto set = ifOp.getIntegerSet();
+    SmallVector<Value *, 4> operands(ifOp.getOperands());
+
+    canonicalizeSetAndOperands(&set, &operands);
+
+    // Any canonicalization change always leads to either a reduction in the
+    // number of operands or a change in the number of symbolic operands
+    // (promotion of dims to symbols).
+    if (operands.size() < ifOp.getIntegerSet().getNumInputs() ||
+        set.getNumSymbols() > ifOp.getIntegerSet().getNumSymbols()) {
+      ifOp.setConditional(set, operands);
+      rewriter.updatedRootInPlace(ifOp);
+      return matchSuccess();
+    }
+
+    return matchFailure();
+  }
+};
+} // end anonymous namespace
+
+void AffineIfOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
+                                             MLIRContext *context) {
+  results.insert<AffineIfOpCanonicalizer>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // AffineLoadOp
 //===----------------------------------------------------------------------===//
diff --git a/third_party/mlir/lib/IR/IntegerSet.cpp b/third_party/mlir/lib/IR/IntegerSet.cpp
index 74a1297dcdd..139ca504b58 100644
--- a/third_party/mlir/lib/IR/IntegerSet.cpp
+++ b/third_party/mlir/lib/IR/IntegerSet.cpp
@@ -24,7 +24,7 @@ using namespace mlir::detail;
 
 unsigned IntegerSet::getNumDims() const { return set->dimCount; }
 unsigned IntegerSet::getNumSymbols() const { return set->symbolCount; }
-unsigned IntegerSet::getNumOperands() const {
+unsigned IntegerSet::getNumInputs() const {
   return set->dimCount + set->symbolCount;
 }
 
@@ -70,3 +70,23 @@ bool IntegerSet::isEq(unsigned idx) const { return getEqFlags()[idx]; }
 MLIRContext *IntegerSet::getContext() const {
   return getConstraint(0).getContext();
 }
+
+/// Walk all of the AffineExpr's in this set. Each node in an expression
+/// tree is visited in postorder.
+void IntegerSet::walkExprs(
+    llvm::function_ref<void(AffineExpr)> callback) const {
+  for (auto expr : getConstraints())
+    expr.walk(callback);
+}
+
+IntegerSet IntegerSet::replaceDimsAndSymbols(
+    ArrayRef<AffineExpr> dimReplacements, ArrayRef<AffineExpr> symReplacements,
+    unsigned numResultDims, unsigned numResultSyms) {
+  SmallVector<AffineExpr, 8> constraints;
+  constraints.reserve(getNumConstraints());
+  for (auto cst : getConstraints())
+    constraints.push_back(
+        cst.replaceDimsAndSymbols(dimReplacements, symReplacements));
+
+  return get(numResultDims, numResultSyms, constraints, getEqFlags());
+}