diff --git a/third_party/mlir/lib/Dialect/StandardOps/Ops.cpp b/third_party/mlir/lib/Dialect/StandardOps/Ops.cpp index c4abee3858e..e38ce065647 100644 --- a/third_party/mlir/lib/Dialect/StandardOps/Ops.cpp +++ b/third_party/mlir/lib/Dialect/StandardOps/Ops.cpp @@ -1347,7 +1347,7 @@ static LogicalResult verify(DimOp op) { OpFoldResult DimOp::fold(ArrayRef operands) { // Constant fold dim when the size along the index referred to is a constant. - auto opType = getOperand()->getType(); + auto opType = memrefOrTensor()->getType(); int64_t indexSize = -1; if (auto tensorType = opType.dyn_cast()) indexSize = tensorType.getShape()[getIndex()]; @@ -1357,6 +1357,14 @@ OpFoldResult DimOp::fold(ArrayRef operands) { if (indexSize >= 0) return IntegerAttr::get(IndexType::get(getContext()), indexSize); + // Fold dim to the size argument of a SubViewOp. + auto memref = memrefOrTensor()->getDefiningOp(); + if (auto subview = dyn_cast_or_null(memref)) { + auto sizes = subview.getDynamicSizes(); + if (!sizes.empty()) + return *(sizes.begin() + getIndex()); + } + return {}; }