Add a CL option to the Standard-to-LLVM lowering to use alloca instead of malloc/free.
In the future, a more configurable malloc and free interface should be used and exposed via extra parameters to `createLowerToLLVMPass`. Until requirements are gathered, a simple CL flag allows generating code that runs successfully on hardware that cannot use the stdlib.

PiperOrigin-RevId: 283833424
Change-Id: I56115a960e7d5a1fc14cabdc71dd3e33d9f6812c
This commit is contained in:
parent
15715cb2c8
commit
41228d7f14
@ -57,25 +57,40 @@ void populateStdToLLVMConversionPatterns(LLVMTypeConverter &converter,
|
||||
OwningRewritePatternList &patterns);
|
||||
|
||||
/// Creates a pass to convert the Standard dialect into the LLVMIR dialect.
|
||||
std::unique_ptr<OpPassBase<ModuleOp>> createLowerToLLVMPass();
|
||||
/// By default stdlib malloc/free are used for allocating MemRef payloads.
|
||||
/// Specifying `useAlloca=true` emits stack allocations instead. In the future
|
||||
/// this may become an enum when we have concrete uses for other options.
|
||||
std::unique_ptr<OpPassBase<ModuleOp>>
|
||||
createLowerToLLVMPass(bool useAlloca = false);
|
||||
|
||||
/// Creates a pass to convert operations to the LLVMIR dialect. The conversion
|
||||
/// is defined by a list of patterns and a type converter that will be obtained
|
||||
/// during the pass using the provided callbacks.
|
||||
/// By default stdlib malloc/free are used for allocating MemRef payloads.
|
||||
/// Specifying `useAlloca=true` emits stack allocations instead. In the future
|
||||
/// this may become an enum when we have concrete uses for other options.
|
||||
std::unique_ptr<OpPassBase<ModuleOp>>
|
||||
createLowerToLLVMPass(LLVMPatternListFiller patternListFiller,
|
||||
LLVMTypeConverterMaker typeConverterMaker);
|
||||
LLVMTypeConverterMaker typeConverterMaker,
|
||||
bool useAlloca = false);
|
||||
|
||||
/// Creates a pass to convert operations to the LLVMIR dialect. The conversion
|
||||
/// is defined by a list of patterns obtained during the pass using the provided
|
||||
/// callback and an optional type conversion class, an instance is created
|
||||
/// during the pass.
|
||||
/// By default stdlib malloc/free are used for allocating MemRef payloads.
|
||||
/// Specifying `useAlloca=true` emits stack allocations instead. In the future
|
||||
/// this may become an enum when we have concrete uses for other options.
|
||||
template <typename TypeConverter = LLVMTypeConverter>
|
||||
std::unique_ptr<OpPassBase<ModuleOp>>
|
||||
createLowerToLLVMPass(LLVMPatternListFiller patternListFiller) {
|
||||
return createLowerToLLVMPass(patternListFiller, [](MLIRContext *context) {
|
||||
return std::make_unique<TypeConverter>(context);
|
||||
});
|
||||
createLowerToLLVMPass(LLVMPatternListFiller patternListFiller,
|
||||
bool useAlloca = false) {
|
||||
return createLowerToLLVMPass(
|
||||
patternListFiller,
|
||||
[](MLIRContext *context) {
|
||||
return std::make_unique<TypeConverter>(context);
|
||||
},
|
||||
useAlloca);
|
||||
}
|
||||
|
||||
namespace LLVM {
|
||||
|
@ -38,9 +38,20 @@
|
||||
#include "llvm/IR/DerivedTypes.h"
|
||||
#include "llvm/IR/IRBuilder.h"
|
||||
#include "llvm/IR/Type.h"
|
||||
#include "llvm/Support/CommandLine.h"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
#define PASS_NAME "convert-std-to-llvm"
|
||||
|
||||
static llvm::cl::OptionCategory
|
||||
clOptionsCategory("Standard to LLVM lowering options");
|
||||
|
||||
static llvm::cl::opt<bool>
|
||||
clUseAlloca(PASS_NAME "-use-alloca",
|
||||
llvm::cl::desc("Replace emission of malloc/free by alloca"),
|
||||
llvm::cl::init(false));
|
||||
|
||||
LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx)
|
||||
: llvmDialect(ctx->getRegisteredDialect<LLVM::LLVMDialect>()) {
|
||||
assert(llvmDialect && "LLVM IR dialect is not registered");
|
||||
@ -764,6 +775,11 @@ static bool isSupportedMemRefType(MemRefType type) {
|
||||
struct AllocOpLowering : public LLVMLegalizationPattern<AllocOp> {
|
||||
using LLVMLegalizationPattern<AllocOp>::LLVMLegalizationPattern;
|
||||
|
||||
AllocOpLowering(LLVM::LLVMDialect &dialect_, LLVMTypeConverter &converter,
|
||||
bool useAlloca = false)
|
||||
: LLVMLegalizationPattern<AllocOp>(dialect_, converter),
|
||||
useAlloca(useAlloca) {}
|
||||
|
||||
PatternMatchResult match(Operation *op) const override {
|
||||
MemRefType type = cast<AllocOp>(op).getType();
|
||||
if (isSupportedMemRefType(type))
|
||||
@ -825,32 +841,43 @@ struct AllocOpLowering : public LLVMLegalizationPattern<AllocOp> {
|
||||
cumulativeSize = rewriter.create<LLVM::MulOp>(
|
||||
loc, getIndexType(), ArrayRef<Value *>{cumulativeSize, elementSize});
|
||||
|
||||
// Insert the `malloc` declaration if it is not already present.
|
||||
auto module = op->getParentOfType<ModuleOp>();
|
||||
auto mallocFunc = module.lookupSymbol<LLVM::LLVMFuncOp>("malloc");
|
||||
if (!mallocFunc) {
|
||||
OpBuilder moduleBuilder(op->getParentOfType<ModuleOp>().getBodyRegion());
|
||||
mallocFunc = moduleBuilder.create<LLVM::LLVMFuncOp>(
|
||||
rewriter.getUnknownLoc(), "malloc",
|
||||
LLVM::LLVMType::getFunctionTy(getVoidPtrType(), getIndexType(),
|
||||
/*isVarArg=*/false));
|
||||
}
|
||||
|
||||
// Allocate the underlying buffer and store a pointer to it in the MemRef
|
||||
// descriptor.
|
||||
Value *align = nullptr;
|
||||
if (auto alignAttr = allocOp.alignment()) {
|
||||
align = createIndexConstant(rewriter, loc,
|
||||
alignAttr.getValue().getSExtValue());
|
||||
cumulativeSize = rewriter.create<LLVM::SubOp>(
|
||||
loc, rewriter.create<LLVM::AddOp>(loc, cumulativeSize, align), one);
|
||||
Value *allocated = nullptr;
|
||||
int alignment = 0;
|
||||
Value *alignmentValue = nullptr;
|
||||
if (auto alignAttr = allocOp.alignment())
|
||||
alignment = alignAttr.getValue().getSExtValue();
|
||||
|
||||
if (useAlloca) {
|
||||
allocated = rewriter.create<LLVM::AllocaOp>(loc, getVoidPtrType(),
|
||||
cumulativeSize, alignment);
|
||||
} else {
|
||||
// Insert the `malloc` declaration if it is not already present.
|
||||
auto module = op->getParentOfType<ModuleOp>();
|
||||
auto mallocFunc = module.lookupSymbol<LLVM::LLVMFuncOp>("malloc");
|
||||
if (!mallocFunc) {
|
||||
OpBuilder moduleBuilder(
|
||||
op->getParentOfType<ModuleOp>().getBodyRegion());
|
||||
mallocFunc = moduleBuilder.create<LLVM::LLVMFuncOp>(
|
||||
rewriter.getUnknownLoc(), "malloc",
|
||||
LLVM::LLVMType::getFunctionTy(getVoidPtrType(), getIndexType(),
|
||||
/*isVarArg=*/false));
|
||||
}
|
||||
if (alignment != 0) {
|
||||
alignmentValue = createIndexConstant(rewriter, loc, alignment);
|
||||
cumulativeSize = rewriter.create<LLVM::SubOp>(
|
||||
loc,
|
||||
rewriter.create<LLVM::AddOp>(loc, cumulativeSize, alignmentValue),
|
||||
one);
|
||||
}
|
||||
allocated = rewriter
|
||||
.create<LLVM::CallOp>(
|
||||
loc, getVoidPtrType(),
|
||||
rewriter.getSymbolRefAttr(mallocFunc), cumulativeSize)
|
||||
.getResult(0);
|
||||
}
|
||||
Value *allocated =
|
||||
rewriter
|
||||
.create<LLVM::CallOp>(loc, getVoidPtrType(),
|
||||
rewriter.getSymbolRefAttr(mallocFunc),
|
||||
cumulativeSize)
|
||||
.getResult(0);
|
||||
|
||||
auto structElementType = lowering.convertType(elementType);
|
||||
auto elementPtrType = structElementType.cast<LLVM::LLVMType>().getPointerTo(
|
||||
type.getMemorySpace());
|
||||
@ -878,13 +905,17 @@ struct AllocOpLowering : public LLVMLegalizationPattern<AllocOp> {
|
||||
|
||||
// Field 2: Actual aligned pointer to payload.
|
||||
Value *bitcastAligned = bitcastAllocated;
|
||||
if (align) {
|
||||
if (!useAlloca && alignment != 0) {
|
||||
assert(alignmentValue);
|
||||
// offset = (align - (ptr % align))% align
|
||||
Value *intVal = rewriter.create<LLVM::PtrToIntOp>(
|
||||
loc, this->getIndexType(), allocated);
|
||||
Value *ptrModAlign = rewriter.create<LLVM::URemOp>(loc, intVal, align);
|
||||
Value *subbed = rewriter.create<LLVM::SubOp>(loc, align, ptrModAlign);
|
||||
Value *offset = rewriter.create<LLVM::URemOp>(loc, subbed, align);
|
||||
Value *ptrModAlign =
|
||||
rewriter.create<LLVM::URemOp>(loc, intVal, alignmentValue);
|
||||
Value *subbed =
|
||||
rewriter.create<LLVM::SubOp>(loc, alignmentValue, ptrModAlign);
|
||||
Value *offset =
|
||||
rewriter.create<LLVM::URemOp>(loc, subbed, alignmentValue);
|
||||
Value *aligned = rewriter.create<LLVM::GEPOp>(loc, allocated->getType(),
|
||||
allocated, offset);
|
||||
bitcastAligned = rewriter.create<LLVM::BitcastOp>(
|
||||
@ -930,6 +961,8 @@ struct AllocOpLowering : public LLVMLegalizationPattern<AllocOp> {
|
||||
// Return the final value of the descriptor.
|
||||
rewriter.replaceOp(op, {memRefDescriptor});
|
||||
}
|
||||
|
||||
bool useAlloca;
|
||||
};
|
||||
|
||||
// A CallOp automatically promotes MemRefType to a sequence of alloca/store and
|
||||
@ -1001,9 +1034,17 @@ struct CallIndirectOpLowering : public CallOpInterfaceLowering<CallIndirectOp> {
|
||||
struct DeallocOpLowering : public LLVMLegalizationPattern<DeallocOp> {
|
||||
using LLVMLegalizationPattern<DeallocOp>::LLVMLegalizationPattern;
|
||||
|
||||
DeallocOpLowering(LLVM::LLVMDialect &dialect_, LLVMTypeConverter &converter,
|
||||
bool useAlloca = false)
|
||||
: LLVMLegalizationPattern<DeallocOp>(dialect_, converter),
|
||||
useAlloca(useAlloca) {}
|
||||
|
||||
PatternMatchResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
if (useAlloca)
|
||||
return rewriter.eraseOp(op), matchSuccess();
|
||||
|
||||
assert(operands.size() == 1 && "dealloc takes one operand");
|
||||
OperandAdaptor<DeallocOp> transformed(operands);
|
||||
|
||||
@ -1026,6 +1067,8 @@ struct DeallocOpLowering : public LLVMLegalizationPattern<DeallocOp> {
|
||||
op, ArrayRef<Type>(), rewriter.getSymbolRefAttr(freeFunc), casted);
|
||||
return matchSuccess();
|
||||
}
|
||||
|
||||
bool useAlloca;
|
||||
};
|
||||
|
||||
struct MemRefCastOpLowering : public LLVMLegalizationPattern<MemRefCastOp> {
|
||||
@ -1759,7 +1802,6 @@ void mlir::populateStdToLLVMConversionPatterns(
|
||||
patterns.insert<
|
||||
AddFOpLowering,
|
||||
AddIOpLowering,
|
||||
AllocOpLowering,
|
||||
AndOpLowering,
|
||||
BranchOpLowering,
|
||||
CallIndirectOpLowering,
|
||||
@ -1768,7 +1810,6 @@ void mlir::populateStdToLLVMConversionPatterns(
|
||||
CmpIOpLowering,
|
||||
CondBranchOpLowering,
|
||||
ConstLLVMOpLowering,
|
||||
DeallocOpLowering,
|
||||
DimOpLowering,
|
||||
DivFOpLowering,
|
||||
DivISOpLowering,
|
||||
@ -1800,6 +1841,10 @@ void mlir::populateStdToLLVMConversionPatterns(
|
||||
ViewOpLowering,
|
||||
XOrOpLowering,
|
||||
ZeroExtendIOpLowering>(*converter.getDialect(), converter);
|
||||
patterns.insert<
|
||||
AllocOpLowering,
|
||||
DeallocOpLowering>(
|
||||
*converter.getDialect(), converter, clUseAlloca.getValue());
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
@ -1873,6 +1918,7 @@ struct LLVMLoweringPass : public ModulePass<LLVMLoweringPass> {
|
||||
// By default, the patterns are those converting Standard operations to the
|
||||
// LLVMIR dialect.
|
||||
explicit LLVMLoweringPass(
|
||||
bool useAlloca = false,
|
||||
LLVMPatternListFiller patternListFiller =
|
||||
populateStdToLLVMConversionPatterns,
|
||||
LLVMTypeConverterMaker converterBuilder = makeStandardToLLVMTypeConverter)
|
||||
@ -1911,17 +1957,25 @@ struct LLVMLoweringPass : public ModulePass<LLVMLoweringPass> {
|
||||
};
|
||||
} // end namespace
|
||||
|
||||
std::unique_ptr<OpPassBase<ModuleOp>> mlir::createLowerToLLVMPass() {
|
||||
return std::make_unique<LLVMLoweringPass>();
|
||||
std::unique_ptr<OpPassBase<ModuleOp>>
|
||||
mlir::createLowerToLLVMPass(bool useAlloca) {
|
||||
return std::make_unique<LLVMLoweringPass>(useAlloca);
|
||||
}
|
||||
|
||||
std::unique_ptr<OpPassBase<ModuleOp>>
|
||||
mlir::createLowerToLLVMPass(LLVMPatternListFiller patternListFiller,
|
||||
LLVMTypeConverterMaker typeConverterMaker) {
|
||||
return std::make_unique<LLVMLoweringPass>(patternListFiller,
|
||||
LLVMTypeConverterMaker typeConverterMaker,
|
||||
bool useAlloca) {
|
||||
return std::make_unique<LLVMLoweringPass>(useAlloca, patternListFiller,
|
||||
typeConverterMaker);
|
||||
}
|
||||
|
||||
static PassRegistration<LLVMLoweringPass>
|
||||
pass("convert-std-to-llvm", "Convert scalar and vector operations from the "
|
||||
"Standard to the LLVM dialect");
|
||||
pass("convert-std-to-llvm",
|
||||
"Convert scalar and vector operations from the "
|
||||
"Standard to the LLVM dialect",
|
||||
[] {
|
||||
return std::make_unique<LLVMLoweringPass>(
|
||||
clUseAlloca.getValue(), populateStdToLLVMConversionPatterns,
|
||||
makeStandardToLLVMTypeConverter);
|
||||
});
|
||||
|
Loading…
Reference in New Issue
Block a user