From 88a3f5c15335ed7ba0f665976bf6087b5e883f17 Mon Sep 17 00:00:00 2001 From: Stephan Herhut Date: Mon, 18 Nov 2019 04:31:02 -0800 Subject: [PATCH] Implement folding of pattern dim(subview(_)[...][s1, ..., sn][...], i) -> si. PiperOrigin-RevId: 281042016 Change-Id: Idf43307b94380915c5cb813f5ff14fa8bf751977 --- third_party/mlir/lib/Dialect/StandardOps/Ops.cpp | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/third_party/mlir/lib/Dialect/StandardOps/Ops.cpp b/third_party/mlir/lib/Dialect/StandardOps/Ops.cpp index c4abee3858e..e38ce065647 100644 --- a/third_party/mlir/lib/Dialect/StandardOps/Ops.cpp +++ b/third_party/mlir/lib/Dialect/StandardOps/Ops.cpp @@ -1347,7 +1347,7 @@ static LogicalResult verify(DimOp op) { OpFoldResult DimOp::fold(ArrayRef operands) { // Constant fold dim when the size along the index referred to is a constant. - auto opType = getOperand()->getType(); + auto opType = memrefOrTensor()->getType(); int64_t indexSize = -1; if (auto tensorType = opType.dyn_cast()) indexSize = tensorType.getShape()[getIndex()]; @@ -1357,6 +1357,14 @@ OpFoldResult DimOp::fold(ArrayRef operands) { if (indexSize >= 0) return IntegerAttr::get(IndexType::get(getContext()), indexSize); + // Fold dim to the size argument of a SubViewOp. + auto memref = memrefOrTensor()->getDefiningOp(); + if (auto subview = dyn_cast_or_null(memref)) { + auto sizes = subview.getDynamicSizes(); + if (!sizes.empty()) + return *(sizes.begin() + getIndex()); + } + return {}; }