Add a CL option to Standard to LLVM lowering to use alloca instead of malloc/free.

In the future, a more configurable malloc and free interface should be used and exposed via
extra parameters to the `createLowerToLLVMPass`. Until requirements are gathered, a simple CL flag allows generating code that runs successfully on hardware that cannot use the stdlib.

PiperOrigin-RevId: 283833424
Change-Id: I56115a960e7d5a1fc14cabdc71dd3e33d9f6812c
This commit is contained in:
Nicolas Vasilache 2019-12-04 14:15:24 -08:00 committed by TensorFlower Gardener
parent 15715cb2c8
commit 41228d7f14
2 changed files with 110 additions and 41 deletions

View File

@ -57,25 +57,40 @@ void populateStdToLLVMConversionPatterns(LLVMTypeConverter &converter,
OwningRewritePatternList &patterns);
/// Creates a pass to convert the Standard dialect into the LLVMIR dialect.
std::unique_ptr<OpPassBase<ModuleOp>> createLowerToLLVMPass();
/// By default stdlib malloc/free are used for allocating MemRef payloads.
/// Specifying `useAlloca=true` emits stack allocations instead. In the future
/// this may become an enum when we have concrete uses for other options.
std::unique_ptr<OpPassBase<ModuleOp>>
createLowerToLLVMPass(bool useAlloca = false);
/// Creates a pass to convert operations to the LLVMIR dialect. The conversion
/// is defined by a list of patterns and a type converter that will be obtained
/// during the pass using the provided callbacks.
/// By default stdlib malloc/free are used for allocating MemRef payloads.
/// Specifying `useAlloca=true` emits stack allocations instead. In the future
/// this may become an enum when we have concrete uses for other options.
std::unique_ptr<OpPassBase<ModuleOp>>
createLowerToLLVMPass(LLVMPatternListFiller patternListFiller,
LLVMTypeConverterMaker typeConverterMaker);
LLVMTypeConverterMaker typeConverterMaker,
bool useAlloca = false);
/// Creates a pass to convert operations to the LLVMIR dialect. The conversion
/// is defined by a list of patterns obtained during the pass using the provided
/// callback and an optional type conversion class, an instance of which is
/// created during the pass.
/// By default stdlib malloc/free are used for allocating MemRef payloads.
/// Specifying `useAlloca=true` emits stack allocations instead. In the future
/// this may become an enum when we have concrete uses for other options.
template <typename TypeConverter = LLVMTypeConverter>
std::unique_ptr<OpPassBase<ModuleOp>>
createLowerToLLVMPass(LLVMPatternListFiller patternListFiller) {
return createLowerToLLVMPass(patternListFiller, [](MLIRContext *context) {
return std::make_unique<TypeConverter>(context);
});
createLowerToLLVMPass(LLVMPatternListFiller patternListFiller,
bool useAlloca = false) {
return createLowerToLLVMPass(
patternListFiller,
[](MLIRContext *context) {
return std::make_unique<TypeConverter>(context);
},
useAlloca);
}
namespace LLVM {

View File

@ -38,9 +38,20 @@
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Type.h"
#include "llvm/Support/CommandLine.h"
using namespace mlir;
#define PASS_NAME "convert-std-to-llvm"
/// Category under which the command-line options of this lowering are grouped
/// in the tool's `--help` output.
static llvm::cl::OptionCategory
    clOptionsCategory("Standard to LLVM lowering options");
/// `-convert-std-to-llvm-use-alloca`: boolean flag, off by default. Its value
/// is read when the AllocOp/DeallocOp lowering patterns are inserted and when
/// the pass is built through its registration, selecting `alloca` emission in
/// place of `malloc`/`free`.
static llvm::cl::opt<bool>
    clUseAlloca(PASS_NAME "-use-alloca",
                llvm::cl::desc("Replace emission of malloc/free by alloca"),
                llvm::cl::init(false),
                // Attach the option to its category; without this the category
                // object above is never referenced and the flag is listed
                // among the generic options.
                llvm::cl::cat(clOptionsCategory));
LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx)
: llvmDialect(ctx->getRegisteredDialect<LLVM::LLVMDialect>()) {
assert(llvmDialect && "LLVM IR dialect is not registered");
@ -764,6 +775,11 @@ static bool isSupportedMemRefType(MemRefType type) {
struct AllocOpLowering : public LLVMLegalizationPattern<AllocOp> {
using LLVMLegalizationPattern<AllocOp>::LLVMLegalizationPattern;
AllocOpLowering(LLVM::LLVMDialect &dialect_, LLVMTypeConverter &converter,
bool useAlloca = false)
: LLVMLegalizationPattern<AllocOp>(dialect_, converter),
useAlloca(useAlloca) {}
PatternMatchResult match(Operation *op) const override {
MemRefType type = cast<AllocOp>(op).getType();
if (isSupportedMemRefType(type))
@ -825,32 +841,43 @@ struct AllocOpLowering : public LLVMLegalizationPattern<AllocOp> {
cumulativeSize = rewriter.create<LLVM::MulOp>(
loc, getIndexType(), ArrayRef<Value *>{cumulativeSize, elementSize});
// Insert the `malloc` declaration if it is not already present.
auto module = op->getParentOfType<ModuleOp>();
auto mallocFunc = module.lookupSymbol<LLVM::LLVMFuncOp>("malloc");
if (!mallocFunc) {
OpBuilder moduleBuilder(op->getParentOfType<ModuleOp>().getBodyRegion());
mallocFunc = moduleBuilder.create<LLVM::LLVMFuncOp>(
rewriter.getUnknownLoc(), "malloc",
LLVM::LLVMType::getFunctionTy(getVoidPtrType(), getIndexType(),
/*isVarArg=*/false));
}
// Allocate the underlying buffer and store a pointer to it in the MemRef
// descriptor.
Value *align = nullptr;
if (auto alignAttr = allocOp.alignment()) {
align = createIndexConstant(rewriter, loc,
alignAttr.getValue().getSExtValue());
cumulativeSize = rewriter.create<LLVM::SubOp>(
loc, rewriter.create<LLVM::AddOp>(loc, cumulativeSize, align), one);
Value *allocated = nullptr;
int alignment = 0;
Value *alignmentValue = nullptr;
if (auto alignAttr = allocOp.alignment())
alignment = alignAttr.getValue().getSExtValue();
if (useAlloca) {
allocated = rewriter.create<LLVM::AllocaOp>(loc, getVoidPtrType(),
cumulativeSize, alignment);
} else {
// Insert the `malloc` declaration if it is not already present.
auto module = op->getParentOfType<ModuleOp>();
auto mallocFunc = module.lookupSymbol<LLVM::LLVMFuncOp>("malloc");
if (!mallocFunc) {
OpBuilder moduleBuilder(
op->getParentOfType<ModuleOp>().getBodyRegion());
mallocFunc = moduleBuilder.create<LLVM::LLVMFuncOp>(
rewriter.getUnknownLoc(), "malloc",
LLVM::LLVMType::getFunctionTy(getVoidPtrType(), getIndexType(),
/*isVarArg=*/false));
}
if (alignment != 0) {
alignmentValue = createIndexConstant(rewriter, loc, alignment);
cumulativeSize = rewriter.create<LLVM::SubOp>(
loc,
rewriter.create<LLVM::AddOp>(loc, cumulativeSize, alignmentValue),
one);
}
allocated = rewriter
.create<LLVM::CallOp>(
loc, getVoidPtrType(),
rewriter.getSymbolRefAttr(mallocFunc), cumulativeSize)
.getResult(0);
}
Value *allocated =
rewriter
.create<LLVM::CallOp>(loc, getVoidPtrType(),
rewriter.getSymbolRefAttr(mallocFunc),
cumulativeSize)
.getResult(0);
auto structElementType = lowering.convertType(elementType);
auto elementPtrType = structElementType.cast<LLVM::LLVMType>().getPointerTo(
type.getMemorySpace());
@ -878,13 +905,17 @@ struct AllocOpLowering : public LLVMLegalizationPattern<AllocOp> {
// Field 2: Actual aligned pointer to payload.
Value *bitcastAligned = bitcastAllocated;
if (align) {
if (!useAlloca && alignment != 0) {
assert(alignmentValue);
// offset = (align - (ptr % align))% align
Value *intVal = rewriter.create<LLVM::PtrToIntOp>(
loc, this->getIndexType(), allocated);
Value *ptrModAlign = rewriter.create<LLVM::URemOp>(loc, intVal, align);
Value *subbed = rewriter.create<LLVM::SubOp>(loc, align, ptrModAlign);
Value *offset = rewriter.create<LLVM::URemOp>(loc, subbed, align);
Value *ptrModAlign =
rewriter.create<LLVM::URemOp>(loc, intVal, alignmentValue);
Value *subbed =
rewriter.create<LLVM::SubOp>(loc, alignmentValue, ptrModAlign);
Value *offset =
rewriter.create<LLVM::URemOp>(loc, subbed, alignmentValue);
Value *aligned = rewriter.create<LLVM::GEPOp>(loc, allocated->getType(),
allocated, offset);
bitcastAligned = rewriter.create<LLVM::BitcastOp>(
@ -930,6 +961,8 @@ struct AllocOpLowering : public LLVMLegalizationPattern<AllocOp> {
// Return the final value of the descriptor.
rewriter.replaceOp(op, {memRefDescriptor});
}
bool useAlloca;
};
// A CallOp automatically promotes MemRefType to a sequence of alloca/store and
@ -1001,9 +1034,17 @@ struct CallIndirectOpLowering : public CallOpInterfaceLowering<CallIndirectOp> {
struct DeallocOpLowering : public LLVMLegalizationPattern<DeallocOp> {
using LLVMLegalizationPattern<DeallocOp>::LLVMLegalizationPattern;
DeallocOpLowering(LLVM::LLVMDialect &dialect_, LLVMTypeConverter &converter,
bool useAlloca = false)
: LLVMLegalizationPattern<DeallocOp>(dialect_, converter),
useAlloca(useAlloca) {}
PatternMatchResult
matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter) const override {
if (useAlloca)
return rewriter.eraseOp(op), matchSuccess();
assert(operands.size() == 1 && "dealloc takes one operand");
OperandAdaptor<DeallocOp> transformed(operands);
@ -1026,6 +1067,8 @@ struct DeallocOpLowering : public LLVMLegalizationPattern<DeallocOp> {
op, ArrayRef<Type>(), rewriter.getSymbolRefAttr(freeFunc), casted);
return matchSuccess();
}
bool useAlloca;
};
struct MemRefCastOpLowering : public LLVMLegalizationPattern<MemRefCastOp> {
@ -1759,7 +1802,6 @@ void mlir::populateStdToLLVMConversionPatterns(
patterns.insert<
AddFOpLowering,
AddIOpLowering,
AllocOpLowering,
AndOpLowering,
BranchOpLowering,
CallIndirectOpLowering,
@ -1768,7 +1810,6 @@ void mlir::populateStdToLLVMConversionPatterns(
CmpIOpLowering,
CondBranchOpLowering,
ConstLLVMOpLowering,
DeallocOpLowering,
DimOpLowering,
DivFOpLowering,
DivISOpLowering,
@ -1800,6 +1841,10 @@ void mlir::populateStdToLLVMConversionPatterns(
ViewOpLowering,
XOrOpLowering,
ZeroExtendIOpLowering>(*converter.getDialect(), converter);
patterns.insert<
AllocOpLowering,
DeallocOpLowering>(
*converter.getDialect(), converter, clUseAlloca.getValue());
// clang-format on
}
@ -1873,6 +1918,7 @@ struct LLVMLoweringPass : public ModulePass<LLVMLoweringPass> {
// By default, the patterns are those converting Standard operations to the
// LLVMIR dialect.
explicit LLVMLoweringPass(
bool useAlloca = false,
LLVMPatternListFiller patternListFiller =
populateStdToLLVMConversionPatterns,
LLVMTypeConverterMaker converterBuilder = makeStandardToLLVMTypeConverter)
@ -1911,17 +1957,25 @@ struct LLVMLoweringPass : public ModulePass<LLVMLoweringPass> {
};
} // end namespace
std::unique_ptr<OpPassBase<ModuleOp>> mlir::createLowerToLLVMPass() {
return std::make_unique<LLVMLoweringPass>();
std::unique_ptr<OpPassBase<ModuleOp>>
mlir::createLowerToLLVMPass(bool useAlloca) {
return std::make_unique<LLVMLoweringPass>(useAlloca);
}
std::unique_ptr<OpPassBase<ModuleOp>>
mlir::createLowerToLLVMPass(LLVMPatternListFiller patternListFiller,
LLVMTypeConverterMaker typeConverterMaker) {
return std::make_unique<LLVMLoweringPass>(patternListFiller,
LLVMTypeConverterMaker typeConverterMaker,
bool useAlloca) {
return std::make_unique<LLVMLoweringPass>(useAlloca, patternListFiller,
typeConverterMaker);
}
static PassRegistration<LLVMLoweringPass>
pass("convert-std-to-llvm", "Convert scalar and vector operations from the "
"Standard to the LLVM dialect");
pass("convert-std-to-llvm",
"Convert scalar and vector operations from the "
"Standard to the LLVM dialect",
[] {
return std::make_unique<LLVMLoweringPass>(
clUseAlloca.getValue(), populateStdToLLVMConversionPatterns,
makeStandardToLLVMTypeConverter);
});