Implement folding of pattern dim(subview(_)[...][s1, ..., sn][...], i) -> si.
PiperOrigin-RevId: 281042016 Change-Id: Idf43307b94380915c5cb813f5ff14fa8bf751977
This commit is contained in:
parent
89dffab453
commit
88a3f5c153
10
third_party/mlir/lib/Dialect/StandardOps/Ops.cpp
vendored
10
third_party/mlir/lib/Dialect/StandardOps/Ops.cpp
vendored
@ -1347,7 +1347,7 @@ static LogicalResult verify(DimOp op) {
|
||||
|
||||
OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) {
|
||||
// Constant fold dim when the size along the index referred to is a constant.
|
||||
auto opType = getOperand()->getType();
|
||||
auto opType = memrefOrTensor()->getType();
|
||||
int64_t indexSize = -1;
|
||||
if (auto tensorType = opType.dyn_cast<RankedTensorType>())
|
||||
indexSize = tensorType.getShape()[getIndex()];
|
||||
@ -1357,6 +1357,14 @@ OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) {
|
||||
if (indexSize >= 0)
|
||||
return IntegerAttr::get(IndexType::get(getContext()), indexSize);
|
||||
|
||||
// Fold dim to the size argument of a SubViewOp.
|
||||
auto memref = memrefOrTensor()->getDefiningOp();
|
||||
if (auto subview = dyn_cast_or_null<SubViewOp>(memref)) {
|
||||
auto sizes = subview.getDynamicSizes();
|
||||
if (!sizes.empty())
|
||||
return *(sizes.begin() + getIndex());
|
||||
}
|
||||
|
||||
return {};
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user