Add canonicalizer for ViewOp which folds constants into the ViewOp memref shape and layout map strides and offset.
PiperOrigin-RevId: 279088023 Change-Id: I36794dc276ed15c5b735603981a5d08b2ec5f465
This commit is contained in:
parent
af20287250
commit
700263d02a
@ -1192,7 +1192,8 @@ def ViewOp : Std_Op<"view"> {
|
|||||||
operand_begin() + 1 + getType().getNumDynamicDims()};
|
operand_begin() + 1 + getType().getNumDynamicDims()};
|
||||||
}
|
}
|
||||||
}];
|
}];
|
||||||
// TODO(andydavis) Add canonicalizer to fold constants into shape and map.
|
|
||||||
|
let hasCanonicalizer = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
def XOrOp : IntArithmeticOp<"xor", [Commutative]> {
|
def XOrOp : IntArithmeticOp<"xor", [Commutative]> {
|
||||||
|
112
third_party/mlir/lib/Dialect/StandardOps/Ops.cpp
vendored
112
third_party/mlir/lib/Dialect/StandardOps/Ops.cpp
vendored
@ -2419,6 +2419,118 @@ static LogicalResult verify(ViewOp op) {
|
|||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
struct ViewOpShapeFolder : public OpRewritePattern<ViewOp> {
|
||||||
|
using OpRewritePattern<ViewOp>::OpRewritePattern;
|
||||||
|
|
||||||
|
PatternMatchResult matchAndRewrite(ViewOp viewOp,
|
||||||
|
PatternRewriter &rewriter) const override {
|
||||||
|
// Return if none of the operands are constants.
|
||||||
|
if (llvm::none_of(viewOp.getOperands(), [](Value *operand) {
|
||||||
|
return matchPattern(operand, m_ConstantIndex());
|
||||||
|
}))
|
||||||
|
return matchFailure();
|
||||||
|
|
||||||
|
// Get result memref type.
|
||||||
|
auto memrefType = viewOp.getType();
|
||||||
|
if (memrefType.getAffineMaps().size() != 1)
|
||||||
|
return matchFailure();
|
||||||
|
auto map = memrefType.getAffineMaps()[0];
|
||||||
|
|
||||||
|
// Fold any dynamic dim operands which are produced by a constant.
|
||||||
|
SmallVector<int64_t, 4> newShapeConstants;
|
||||||
|
newShapeConstants.reserve(memrefType.getRank());
|
||||||
|
SmallVector<Value *, 4> newOperands;
|
||||||
|
SmallVector<Value *, 4> droppedOperands;
|
||||||
|
|
||||||
|
unsigned dynamicDimPos = 1;
|
||||||
|
unsigned rank = memrefType.getRank();
|
||||||
|
for (unsigned dim = 0, e = rank; dim < e; ++dim) {
|
||||||
|
int64_t dimSize = memrefType.getDimSize(dim);
|
||||||
|
// If this is already static dimension, keep it.
|
||||||
|
if (!ShapedType::isDynamic(dimSize)) {
|
||||||
|
newShapeConstants.push_back(dimSize);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
auto *defOp = viewOp.getOperand(dynamicDimPos)->getDefiningOp();
|
||||||
|
if (auto constantIndexOp = dyn_cast_or_null<ConstantIndexOp>(defOp)) {
|
||||||
|
// Dynamic shape dimension will be folded.
|
||||||
|
newShapeConstants.push_back(constantIndexOp.getValue());
|
||||||
|
// Record to check for zero uses later below.
|
||||||
|
droppedOperands.push_back(constantIndexOp);
|
||||||
|
} else {
|
||||||
|
// Dynamic shape dimension not folded; copy operand from old memref.
|
||||||
|
newShapeConstants.push_back(dimSize);
|
||||||
|
newOperands.push_back(viewOp.getOperand(dynamicDimPos));
|
||||||
|
}
|
||||||
|
dynamicDimPos++;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get offset from old memref view type 'memRefType'.
|
||||||
|
int64_t oldOffset;
|
||||||
|
llvm::SmallVector<int64_t, 4> oldStrides;
|
||||||
|
if (failed(getStridesAndOffset(memrefType, oldStrides, oldOffset)))
|
||||||
|
return matchFailure();
|
||||||
|
|
||||||
|
// Fold dynamic offset operand if it is produced by a constant.
|
||||||
|
auto *dynamicOffset = viewOp.getDynamicOffset();
|
||||||
|
int64_t newOffset = oldOffset;
|
||||||
|
unsigned dynamicOffsetOperandCount = 0;
|
||||||
|
if (dynamicOffset != nullptr) {
|
||||||
|
auto *defOp = dynamicOffset->getDefiningOp();
|
||||||
|
if (auto constantIndexOp = dyn_cast_or_null<ConstantIndexOp>(defOp)) {
|
||||||
|
// Dynamic offset will be folded into the map.
|
||||||
|
newOffset = constantIndexOp.getValue();
|
||||||
|
droppedOperands.push_back(dynamicOffset);
|
||||||
|
} else {
|
||||||
|
// Unable to fold dynamic offset. Add it to 'newOperands' list.
|
||||||
|
newOperands.push_back(dynamicOffset);
|
||||||
|
dynamicOffsetOperandCount = 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compute new strides based on 'newShapeConstants'.
|
||||||
|
SmallVector<int64_t, 4> newStrides(rank);
|
||||||
|
newStrides[rank - 1] = 1;
|
||||||
|
bool dynamicStrides = false;
|
||||||
|
for (int i = rank - 2; i >= 0; --i) {
|
||||||
|
if (ShapedType::isDynamic(newShapeConstants[i + 1]))
|
||||||
|
dynamicStrides = true;
|
||||||
|
if (dynamicStrides)
|
||||||
|
newStrides[i] = MemRefType::getDynamicStrideOrOffset();
|
||||||
|
else
|
||||||
|
newStrides[i] = newShapeConstants[i + 1] * newStrides[i + 1];
|
||||||
|
}
|
||||||
|
|
||||||
|
// Regenerate strided layout map with 'newStrides' and 'newOffset'.
|
||||||
|
map = makeStridedLinearLayoutMap(newStrides, newOffset,
|
||||||
|
rewriter.getContext());
|
||||||
|
|
||||||
|
// Create new memref type with constant folded dims and/or offset/strides.
|
||||||
|
auto newMemRefType =
|
||||||
|
MemRefType::get(newShapeConstants, memrefType.getElementType(), {map},
|
||||||
|
memrefType.getMemorySpace());
|
||||||
|
assert(static_cast<int64_t>(newOperands.size()) ==
|
||||||
|
dynamicOffsetOperandCount + newMemRefType.getNumDynamicDims());
|
||||||
|
|
||||||
|
// Create new ViewOp.
|
||||||
|
auto newShapeCastOp = rewriter.create<ViewOp>(
|
||||||
|
viewOp.getLoc(), newMemRefType, viewOp.getOperand(0), newOperands);
|
||||||
|
// Insert a cast so we have the same type as the old memref type.
|
||||||
|
rewriter.replaceOpWithNewOp<MemRefCastOp>(droppedOperands, viewOp,
|
||||||
|
newShapeCastOp, viewOp.getType());
|
||||||
|
return matchSuccess();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
} // end anonymous namespace
|
||||||
|
|
||||||
|
void ViewOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
||||||
|
MLIRContext *context) {
|
||||||
|
results.insert<ViewOpShapeFolder>(context);
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// ZeroExtendIOp
|
// ZeroExtendIOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
Loading…
x
Reference in New Issue
Block a user