Add a CL option to Standard to LLVM lowering to use alloca instead of malloc/free.
In the future, a more configurable malloc and free interface should be used and exposed via extra parameters to the `createLowerToLLVMPass`. Until requirements are gathered, a simple CL flag allows generating code that runs successfully on hardware that cannot use the stdlib. PiperOrigin-RevId: 283833424 Change-Id: I56115a960e7d5a1fc14cabdc71dd3e33d9f6812c
This commit is contained in:
parent
15715cb2c8
commit
41228d7f14
third_party/mlir
include/mlir/Conversion/StandardToLLVM
lib/Conversion/StandardToLLVM
@ -57,25 +57,40 @@ void populateStdToLLVMConversionPatterns(LLVMTypeConverter &converter,
|
|||||||
OwningRewritePatternList &patterns);
|
OwningRewritePatternList &patterns);
|
||||||
|
|
||||||
/// Creates a pass to convert the Standard dialect into the LLVMIR dialect.
|
/// Creates a pass to convert the Standard dialect into the LLVMIR dialect.
|
||||||
std::unique_ptr<OpPassBase<ModuleOp>> createLowerToLLVMPass();
|
/// By default stdlib malloc/free are used for allocating MemRef payloads.
|
||||||
|
/// Specifying `useAlloca-true` emits stack allocations instead. In the future
|
||||||
|
/// this may become an enum when we have concrete uses for other options.
|
||||||
|
std::unique_ptr<OpPassBase<ModuleOp>>
|
||||||
|
createLowerToLLVMPass(bool useAlloca = false);
|
||||||
|
|
||||||
/// Creates a pass to convert operations to the LLVMIR dialect. The conversion
|
/// Creates a pass to convert operations to the LLVMIR dialect. The conversion
|
||||||
/// is defined by a list of patterns and a type converter that will be obtained
|
/// is defined by a list of patterns and a type converter that will be obtained
|
||||||
/// during the pass using the provided callbacks.
|
/// during the pass using the provided callbacks.
|
||||||
|
/// By default stdlib malloc/free are used for allocating MemRef payloads.
|
||||||
|
/// Specifying `useAlloca-true` emits stack allocations instead. In the future
|
||||||
|
/// this may become an enum when we have concrete uses for other options.
|
||||||
std::unique_ptr<OpPassBase<ModuleOp>>
|
std::unique_ptr<OpPassBase<ModuleOp>>
|
||||||
createLowerToLLVMPass(LLVMPatternListFiller patternListFiller,
|
createLowerToLLVMPass(LLVMPatternListFiller patternListFiller,
|
||||||
LLVMTypeConverterMaker typeConverterMaker);
|
LLVMTypeConverterMaker typeConverterMaker,
|
||||||
|
bool useAlloca = false);
|
||||||
|
|
||||||
/// Creates a pass to convert operations to the LLVMIR dialect. The conversion
|
/// Creates a pass to convert operations to the LLVMIR dialect. The conversion
|
||||||
/// is defined by a list of patterns obtained during the pass using the provided
|
/// is defined by a list of patterns obtained during the pass using the provided
|
||||||
/// callback and an optional type conversion class, an instance is created
|
/// callback and an optional type conversion class, an instance is created
|
||||||
/// during the pass.
|
/// during the pass.
|
||||||
|
/// By default stdlib malloc/free are used for allocating MemRef payloads.
|
||||||
|
/// Specifying `useAlloca-true` emits stack allocations instead. In the future
|
||||||
|
/// this may become an enum when we have concrete uses for other options.
|
||||||
template <typename TypeConverter = LLVMTypeConverter>
|
template <typename TypeConverter = LLVMTypeConverter>
|
||||||
std::unique_ptr<OpPassBase<ModuleOp>>
|
std::unique_ptr<OpPassBase<ModuleOp>>
|
||||||
createLowerToLLVMPass(LLVMPatternListFiller patternListFiller) {
|
createLowerToLLVMPass(LLVMPatternListFiller patternListFiller,
|
||||||
return createLowerToLLVMPass(patternListFiller, [](MLIRContext *context) {
|
bool useAlloca = false) {
|
||||||
return std::make_unique<TypeConverter>(context);
|
return createLowerToLLVMPass(
|
||||||
});
|
patternListFiller,
|
||||||
|
[](MLIRContext *context) {
|
||||||
|
return std::make_unique<TypeConverter>(context);
|
||||||
|
},
|
||||||
|
useAlloca);
|
||||||
}
|
}
|
||||||
|
|
||||||
namespace LLVM {
|
namespace LLVM {
|
||||||
|
@ -38,9 +38,20 @@
|
|||||||
#include "llvm/IR/DerivedTypes.h"
|
#include "llvm/IR/DerivedTypes.h"
|
||||||
#include "llvm/IR/IRBuilder.h"
|
#include "llvm/IR/IRBuilder.h"
|
||||||
#include "llvm/IR/Type.h"
|
#include "llvm/IR/Type.h"
|
||||||
|
#include "llvm/Support/CommandLine.h"
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
|
|
||||||
|
#define PASS_NAME "convert-std-to-llvm"
|
||||||
|
|
||||||
|
static llvm::cl::OptionCategory
|
||||||
|
clOptionsCategory("Standard to LLVM lowering options");
|
||||||
|
|
||||||
|
static llvm::cl::opt<bool>
|
||||||
|
clUseAlloca(PASS_NAME "-use-alloca",
|
||||||
|
llvm::cl::desc("Replace emission of malloc/free by alloca"),
|
||||||
|
llvm::cl::init(false));
|
||||||
|
|
||||||
LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx)
|
LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx)
|
||||||
: llvmDialect(ctx->getRegisteredDialect<LLVM::LLVMDialect>()) {
|
: llvmDialect(ctx->getRegisteredDialect<LLVM::LLVMDialect>()) {
|
||||||
assert(llvmDialect && "LLVM IR dialect is not registered");
|
assert(llvmDialect && "LLVM IR dialect is not registered");
|
||||||
@ -764,6 +775,11 @@ static bool isSupportedMemRefType(MemRefType type) {
|
|||||||
struct AllocOpLowering : public LLVMLegalizationPattern<AllocOp> {
|
struct AllocOpLowering : public LLVMLegalizationPattern<AllocOp> {
|
||||||
using LLVMLegalizationPattern<AllocOp>::LLVMLegalizationPattern;
|
using LLVMLegalizationPattern<AllocOp>::LLVMLegalizationPattern;
|
||||||
|
|
||||||
|
AllocOpLowering(LLVM::LLVMDialect &dialect_, LLVMTypeConverter &converter,
|
||||||
|
bool useAlloca = false)
|
||||||
|
: LLVMLegalizationPattern<AllocOp>(dialect_, converter),
|
||||||
|
useAlloca(useAlloca) {}
|
||||||
|
|
||||||
PatternMatchResult match(Operation *op) const override {
|
PatternMatchResult match(Operation *op) const override {
|
||||||
MemRefType type = cast<AllocOp>(op).getType();
|
MemRefType type = cast<AllocOp>(op).getType();
|
||||||
if (isSupportedMemRefType(type))
|
if (isSupportedMemRefType(type))
|
||||||
@ -825,32 +841,43 @@ struct AllocOpLowering : public LLVMLegalizationPattern<AllocOp> {
|
|||||||
cumulativeSize = rewriter.create<LLVM::MulOp>(
|
cumulativeSize = rewriter.create<LLVM::MulOp>(
|
||||||
loc, getIndexType(), ArrayRef<Value *>{cumulativeSize, elementSize});
|
loc, getIndexType(), ArrayRef<Value *>{cumulativeSize, elementSize});
|
||||||
|
|
||||||
// Insert the `malloc` declaration if it is not already present.
|
|
||||||
auto module = op->getParentOfType<ModuleOp>();
|
|
||||||
auto mallocFunc = module.lookupSymbol<LLVM::LLVMFuncOp>("malloc");
|
|
||||||
if (!mallocFunc) {
|
|
||||||
OpBuilder moduleBuilder(op->getParentOfType<ModuleOp>().getBodyRegion());
|
|
||||||
mallocFunc = moduleBuilder.create<LLVM::LLVMFuncOp>(
|
|
||||||
rewriter.getUnknownLoc(), "malloc",
|
|
||||||
LLVM::LLVMType::getFunctionTy(getVoidPtrType(), getIndexType(),
|
|
||||||
/*isVarArg=*/false));
|
|
||||||
}
|
|
||||||
|
|
||||||
// Allocate the underlying buffer and store a pointer to it in the MemRef
|
// Allocate the underlying buffer and store a pointer to it in the MemRef
|
||||||
// descriptor.
|
// descriptor.
|
||||||
Value *align = nullptr;
|
Value *allocated = nullptr;
|
||||||
if (auto alignAttr = allocOp.alignment()) {
|
int alignment = 0;
|
||||||
align = createIndexConstant(rewriter, loc,
|
Value *alignmentValue = nullptr;
|
||||||
alignAttr.getValue().getSExtValue());
|
if (auto alignAttr = allocOp.alignment())
|
||||||
cumulativeSize = rewriter.create<LLVM::SubOp>(
|
alignment = alignAttr.getValue().getSExtValue();
|
||||||
loc, rewriter.create<LLVM::AddOp>(loc, cumulativeSize, align), one);
|
|
||||||
|
if (useAlloca) {
|
||||||
|
allocated = rewriter.create<LLVM::AllocaOp>(loc, getVoidPtrType(),
|
||||||
|
cumulativeSize, alignment);
|
||||||
|
} else {
|
||||||
|
// Insert the `malloc` declaration if it is not already present.
|
||||||
|
auto module = op->getParentOfType<ModuleOp>();
|
||||||
|
auto mallocFunc = module.lookupSymbol<LLVM::LLVMFuncOp>("malloc");
|
||||||
|
if (!mallocFunc) {
|
||||||
|
OpBuilder moduleBuilder(
|
||||||
|
op->getParentOfType<ModuleOp>().getBodyRegion());
|
||||||
|
mallocFunc = moduleBuilder.create<LLVM::LLVMFuncOp>(
|
||||||
|
rewriter.getUnknownLoc(), "malloc",
|
||||||
|
LLVM::LLVMType::getFunctionTy(getVoidPtrType(), getIndexType(),
|
||||||
|
/*isVarArg=*/false));
|
||||||
|
}
|
||||||
|
if (alignment != 0) {
|
||||||
|
alignmentValue = createIndexConstant(rewriter, loc, alignment);
|
||||||
|
cumulativeSize = rewriter.create<LLVM::SubOp>(
|
||||||
|
loc,
|
||||||
|
rewriter.create<LLVM::AddOp>(loc, cumulativeSize, alignmentValue),
|
||||||
|
one);
|
||||||
|
}
|
||||||
|
allocated = rewriter
|
||||||
|
.create<LLVM::CallOp>(
|
||||||
|
loc, getVoidPtrType(),
|
||||||
|
rewriter.getSymbolRefAttr(mallocFunc), cumulativeSize)
|
||||||
|
.getResult(0);
|
||||||
}
|
}
|
||||||
Value *allocated =
|
|
||||||
rewriter
|
|
||||||
.create<LLVM::CallOp>(loc, getVoidPtrType(),
|
|
||||||
rewriter.getSymbolRefAttr(mallocFunc),
|
|
||||||
cumulativeSize)
|
|
||||||
.getResult(0);
|
|
||||||
auto structElementType = lowering.convertType(elementType);
|
auto structElementType = lowering.convertType(elementType);
|
||||||
auto elementPtrType = structElementType.cast<LLVM::LLVMType>().getPointerTo(
|
auto elementPtrType = structElementType.cast<LLVM::LLVMType>().getPointerTo(
|
||||||
type.getMemorySpace());
|
type.getMemorySpace());
|
||||||
@ -878,13 +905,17 @@ struct AllocOpLowering : public LLVMLegalizationPattern<AllocOp> {
|
|||||||
|
|
||||||
// Field 2: Actual aligned pointer to payload.
|
// Field 2: Actual aligned pointer to payload.
|
||||||
Value *bitcastAligned = bitcastAllocated;
|
Value *bitcastAligned = bitcastAllocated;
|
||||||
if (align) {
|
if (!useAlloca && alignment != 0) {
|
||||||
|
assert(alignmentValue);
|
||||||
// offset = (align - (ptr % align))% align
|
// offset = (align - (ptr % align))% align
|
||||||
Value *intVal = rewriter.create<LLVM::PtrToIntOp>(
|
Value *intVal = rewriter.create<LLVM::PtrToIntOp>(
|
||||||
loc, this->getIndexType(), allocated);
|
loc, this->getIndexType(), allocated);
|
||||||
Value *ptrModAlign = rewriter.create<LLVM::URemOp>(loc, intVal, align);
|
Value *ptrModAlign =
|
||||||
Value *subbed = rewriter.create<LLVM::SubOp>(loc, align, ptrModAlign);
|
rewriter.create<LLVM::URemOp>(loc, intVal, alignmentValue);
|
||||||
Value *offset = rewriter.create<LLVM::URemOp>(loc, subbed, align);
|
Value *subbed =
|
||||||
|
rewriter.create<LLVM::SubOp>(loc, alignmentValue, ptrModAlign);
|
||||||
|
Value *offset =
|
||||||
|
rewriter.create<LLVM::URemOp>(loc, subbed, alignmentValue);
|
||||||
Value *aligned = rewriter.create<LLVM::GEPOp>(loc, allocated->getType(),
|
Value *aligned = rewriter.create<LLVM::GEPOp>(loc, allocated->getType(),
|
||||||
allocated, offset);
|
allocated, offset);
|
||||||
bitcastAligned = rewriter.create<LLVM::BitcastOp>(
|
bitcastAligned = rewriter.create<LLVM::BitcastOp>(
|
||||||
@ -930,6 +961,8 @@ struct AllocOpLowering : public LLVMLegalizationPattern<AllocOp> {
|
|||||||
// Return the final value of the descriptor.
|
// Return the final value of the descriptor.
|
||||||
rewriter.replaceOp(op, {memRefDescriptor});
|
rewriter.replaceOp(op, {memRefDescriptor});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool useAlloca;
|
||||||
};
|
};
|
||||||
|
|
||||||
// A CallOp automatically promotes MemRefType to a sequence of alloca/store and
|
// A CallOp automatically promotes MemRefType to a sequence of alloca/store and
|
||||||
@ -1001,9 +1034,17 @@ struct CallIndirectOpLowering : public CallOpInterfaceLowering<CallIndirectOp> {
|
|||||||
struct DeallocOpLowering : public LLVMLegalizationPattern<DeallocOp> {
|
struct DeallocOpLowering : public LLVMLegalizationPattern<DeallocOp> {
|
||||||
using LLVMLegalizationPattern<DeallocOp>::LLVMLegalizationPattern;
|
using LLVMLegalizationPattern<DeallocOp>::LLVMLegalizationPattern;
|
||||||
|
|
||||||
|
DeallocOpLowering(LLVM::LLVMDialect &dialect_, LLVMTypeConverter &converter,
|
||||||
|
bool useAlloca = false)
|
||||||
|
: LLVMLegalizationPattern<DeallocOp>(dialect_, converter),
|
||||||
|
useAlloca(useAlloca) {}
|
||||||
|
|
||||||
PatternMatchResult
|
PatternMatchResult
|
||||||
matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
|
matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
|
||||||
ConversionPatternRewriter &rewriter) const override {
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
|
if (useAlloca)
|
||||||
|
return rewriter.eraseOp(op), matchSuccess();
|
||||||
|
|
||||||
assert(operands.size() == 1 && "dealloc takes one operand");
|
assert(operands.size() == 1 && "dealloc takes one operand");
|
||||||
OperandAdaptor<DeallocOp> transformed(operands);
|
OperandAdaptor<DeallocOp> transformed(operands);
|
||||||
|
|
||||||
@ -1026,6 +1067,8 @@ struct DeallocOpLowering : public LLVMLegalizationPattern<DeallocOp> {
|
|||||||
op, ArrayRef<Type>(), rewriter.getSymbolRefAttr(freeFunc), casted);
|
op, ArrayRef<Type>(), rewriter.getSymbolRefAttr(freeFunc), casted);
|
||||||
return matchSuccess();
|
return matchSuccess();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool useAlloca;
|
||||||
};
|
};
|
||||||
|
|
||||||
struct MemRefCastOpLowering : public LLVMLegalizationPattern<MemRefCastOp> {
|
struct MemRefCastOpLowering : public LLVMLegalizationPattern<MemRefCastOp> {
|
||||||
@ -1759,7 +1802,6 @@ void mlir::populateStdToLLVMConversionPatterns(
|
|||||||
patterns.insert<
|
patterns.insert<
|
||||||
AddFOpLowering,
|
AddFOpLowering,
|
||||||
AddIOpLowering,
|
AddIOpLowering,
|
||||||
AllocOpLowering,
|
|
||||||
AndOpLowering,
|
AndOpLowering,
|
||||||
BranchOpLowering,
|
BranchOpLowering,
|
||||||
CallIndirectOpLowering,
|
CallIndirectOpLowering,
|
||||||
@ -1768,7 +1810,6 @@ void mlir::populateStdToLLVMConversionPatterns(
|
|||||||
CmpIOpLowering,
|
CmpIOpLowering,
|
||||||
CondBranchOpLowering,
|
CondBranchOpLowering,
|
||||||
ConstLLVMOpLowering,
|
ConstLLVMOpLowering,
|
||||||
DeallocOpLowering,
|
|
||||||
DimOpLowering,
|
DimOpLowering,
|
||||||
DivFOpLowering,
|
DivFOpLowering,
|
||||||
DivISOpLowering,
|
DivISOpLowering,
|
||||||
@ -1800,6 +1841,10 @@ void mlir::populateStdToLLVMConversionPatterns(
|
|||||||
ViewOpLowering,
|
ViewOpLowering,
|
||||||
XOrOpLowering,
|
XOrOpLowering,
|
||||||
ZeroExtendIOpLowering>(*converter.getDialect(), converter);
|
ZeroExtendIOpLowering>(*converter.getDialect(), converter);
|
||||||
|
patterns.insert<
|
||||||
|
AllocOpLowering,
|
||||||
|
DeallocOpLowering>(
|
||||||
|
*converter.getDialect(), converter, clUseAlloca.getValue());
|
||||||
// clang-format on
|
// clang-format on
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1873,6 +1918,7 @@ struct LLVMLoweringPass : public ModulePass<LLVMLoweringPass> {
|
|||||||
// By default, the patterns are those converting Standard operations to the
|
// By default, the patterns are those converting Standard operations to the
|
||||||
// LLVMIR dialect.
|
// LLVMIR dialect.
|
||||||
explicit LLVMLoweringPass(
|
explicit LLVMLoweringPass(
|
||||||
|
bool useAlloca = false,
|
||||||
LLVMPatternListFiller patternListFiller =
|
LLVMPatternListFiller patternListFiller =
|
||||||
populateStdToLLVMConversionPatterns,
|
populateStdToLLVMConversionPatterns,
|
||||||
LLVMTypeConverterMaker converterBuilder = makeStandardToLLVMTypeConverter)
|
LLVMTypeConverterMaker converterBuilder = makeStandardToLLVMTypeConverter)
|
||||||
@ -1911,17 +1957,25 @@ struct LLVMLoweringPass : public ModulePass<LLVMLoweringPass> {
|
|||||||
};
|
};
|
||||||
} // end namespace
|
} // end namespace
|
||||||
|
|
||||||
std::unique_ptr<OpPassBase<ModuleOp>> mlir::createLowerToLLVMPass() {
|
std::unique_ptr<OpPassBase<ModuleOp>>
|
||||||
return std::make_unique<LLVMLoweringPass>();
|
mlir::createLowerToLLVMPass(bool useAlloca) {
|
||||||
|
return std::make_unique<LLVMLoweringPass>(useAlloca);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::unique_ptr<OpPassBase<ModuleOp>>
|
std::unique_ptr<OpPassBase<ModuleOp>>
|
||||||
mlir::createLowerToLLVMPass(LLVMPatternListFiller patternListFiller,
|
mlir::createLowerToLLVMPass(LLVMPatternListFiller patternListFiller,
|
||||||
LLVMTypeConverterMaker typeConverterMaker) {
|
LLVMTypeConverterMaker typeConverterMaker,
|
||||||
return std::make_unique<LLVMLoweringPass>(patternListFiller,
|
bool useAlloca) {
|
||||||
|
return std::make_unique<LLVMLoweringPass>(useAlloca, patternListFiller,
|
||||||
typeConverterMaker);
|
typeConverterMaker);
|
||||||
}
|
}
|
||||||
|
|
||||||
static PassRegistration<LLVMLoweringPass>
|
static PassRegistration<LLVMLoweringPass>
|
||||||
pass("convert-std-to-llvm", "Convert scalar and vector operations from the "
|
pass("convert-std-to-llvm",
|
||||||
"Standard to the LLVM dialect");
|
"Convert scalar and vector operations from the "
|
||||||
|
"Standard to the LLVM dialect",
|
||||||
|
[] {
|
||||||
|
return std::make_unique<LLVMLoweringPass>(
|
||||||
|
clUseAlloca.getValue(), populateStdToLLVMConversionPatterns,
|
||||||
|
makeStandardToLLVMTypeConverter);
|
||||||
|
});
|
||||||
|
Loading…
Reference in New Issue
Block a user