Implement folding of pattern dim(subview(_)[...][s1, ..., sn][...], i) -> si.

PiperOrigin-RevId: 281042016
Change-Id: Idf43307b94380915c5cb813f5ff14fa8bf751977
This commit is contained in:
Stephan Herhut 2019-11-18 04:31:02 -08:00 committed by TensorFlower Gardener
parent 89dffab453
commit 88a3f5c153

View File

@ -1347,7 +1347,7 @@ static LogicalResult verify(DimOp op) {
OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) {
// Constant fold dim when the size along the index referred to is a constant.
auto opType = getOperand()->getType();
auto opType = memrefOrTensor()->getType();
int64_t indexSize = -1;
if (auto tensorType = opType.dyn_cast<RankedTensorType>())
indexSize = tensorType.getShape()[getIndex()];
@ -1357,6 +1357,14 @@ OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) {
if (indexSize >= 0)
return IntegerAttr::get(IndexType::get(getContext()), indexSize);
// Fold dim to the size argument of a SubViewOp.
auto memref = memrefOrTensor()->getDefiningOp();
if (auto subview = dyn_cast_or_null<SubViewOp>(memref)) {
auto sizes = subview.getDynamicSizes();
if (!sizes.empty())
return *(sizes.begin() + getIndex());
}
return {};
}