diff --git a/third_party/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h b/third_party/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h index 98e105aa2b5..c5c17b36f5e 100644 --- a/third_party/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h +++ b/third_party/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h @@ -57,25 +57,40 @@ void populateStdToLLVMConversionPatterns(LLVMTypeConverter &converter, OwningRewritePatternList &patterns); /// Creates a pass to convert the Standard dialect into the LLVMIR dialect. -std::unique_ptr> createLowerToLLVMPass(); +/// By default stdlib malloc/free are used for allocating MemRef payloads. +/// Specifying `useAlloca-true` emits stack allocations instead. In the future +/// this may become an enum when we have concrete uses for other options. +std::unique_ptr> +createLowerToLLVMPass(bool useAlloca = false); /// Creates a pass to convert operations to the LLVMIR dialect. The conversion /// is defined by a list of patterns and a type converter that will be obtained /// during the pass using the provided callbacks. +/// By default stdlib malloc/free are used for allocating MemRef payloads. +/// Specifying `useAlloca-true` emits stack allocations instead. In the future +/// this may become an enum when we have concrete uses for other options. std::unique_ptr> createLowerToLLVMPass(LLVMPatternListFiller patternListFiller, - LLVMTypeConverterMaker typeConverterMaker); + LLVMTypeConverterMaker typeConverterMaker, + bool useAlloca = false); /// Creates a pass to convert operations to the LLVMIR dialect. The conversion /// is defined by a list of patterns obtained during the pass using the provided /// callback and an optional type conversion class, an instance is created /// during the pass. +/// By default stdlib malloc/free are used for allocating MemRef payloads. +/// Specifying `useAlloca-true` emits stack allocations instead. In the future +/// this may become an enum when we have concrete uses for other options. template std::unique_ptr> -createLowerToLLVMPass(LLVMPatternListFiller patternListFiller) { - return createLowerToLLVMPass(patternListFiller, [](MLIRContext *context) { - return std::make_unique(context); - }); +createLowerToLLVMPass(LLVMPatternListFiller patternListFiller, + bool useAlloca = false) { + return createLowerToLLVMPass( + patternListFiller, + [](MLIRContext *context) { + return std::make_unique(context); + }, + useAlloca); } namespace LLVM { diff --git a/third_party/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp b/third_party/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp index 793997e9045..23c7be310a9 100644 --- a/third_party/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp +++ b/third_party/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp @@ -38,9 +38,20 @@ #include "llvm/IR/DerivedTypes.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Type.h" +#include "llvm/Support/CommandLine.h" using namespace mlir; +#define PASS_NAME "convert-std-to-llvm" + +static llvm::cl::OptionCategory + clOptionsCategory("Standard to LLVM lowering options"); + +static llvm::cl::opt + clUseAlloca(PASS_NAME "-use-alloca", + llvm::cl::desc("Replace emission of malloc/free by alloca"), + llvm::cl::init(false)); + LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx) : llvmDialect(ctx->getRegisteredDialect()) { assert(llvmDialect && "LLVM IR dialect is not registered"); @@ -764,6 +775,11 @@ static bool isSupportedMemRefType(MemRefType type) { struct AllocOpLowering : public LLVMLegalizationPattern { using LLVMLegalizationPattern::LLVMLegalizationPattern; + AllocOpLowering(LLVM::LLVMDialect &dialect_, LLVMTypeConverter &converter, + bool useAlloca = false) + : LLVMLegalizationPattern(dialect_, converter), + useAlloca(useAlloca) {} + PatternMatchResult match(Operation *op) const override { MemRefType type = cast(op).getType(); if (isSupportedMemRefType(type)) @@ -825,32 +841,43 @@ struct AllocOpLowering : public LLVMLegalizationPattern { cumulativeSize = rewriter.create( loc, getIndexType(), ArrayRef{cumulativeSize, elementSize}); - // Insert the `malloc` declaration if it is not already present. - auto module = op->getParentOfType(); - auto mallocFunc = module.lookupSymbol("malloc"); - if (!mallocFunc) { - OpBuilder moduleBuilder(op->getParentOfType().getBodyRegion()); - mallocFunc = moduleBuilder.create( - rewriter.getUnknownLoc(), "malloc", - LLVM::LLVMType::getFunctionTy(getVoidPtrType(), getIndexType(), - /*isVarArg=*/false)); - } - // Allocate the underlying buffer and store a pointer to it in the MemRef // descriptor. - Value *align = nullptr; - if (auto alignAttr = allocOp.alignment()) { - align = createIndexConstant(rewriter, loc, - alignAttr.getValue().getSExtValue()); - cumulativeSize = rewriter.create( - loc, rewriter.create(loc, cumulativeSize, align), one); + Value *allocated = nullptr; + int alignment = 0; + Value *alignmentValue = nullptr; + if (auto alignAttr = allocOp.alignment()) + alignment = alignAttr.getValue().getSExtValue(); + + if (useAlloca) { + allocated = rewriter.create(loc, getVoidPtrType(), + cumulativeSize, alignment); + } else { + // Insert the `malloc` declaration if it is not already present. + auto module = op->getParentOfType(); + auto mallocFunc = module.lookupSymbol("malloc"); + if (!mallocFunc) { + OpBuilder moduleBuilder( + op->getParentOfType().getBodyRegion()); + mallocFunc = moduleBuilder.create( + rewriter.getUnknownLoc(), "malloc", + LLVM::LLVMType::getFunctionTy(getVoidPtrType(), getIndexType(), + /*isVarArg=*/false)); + } + if (alignment != 0) { + alignmentValue = createIndexConstant(rewriter, loc, alignment); + cumulativeSize = rewriter.create( + loc, + rewriter.create(loc, cumulativeSize, alignmentValue), + one); + } + allocated = rewriter + .create( + loc, getVoidPtrType(), + rewriter.getSymbolRefAttr(mallocFunc), cumulativeSize) + .getResult(0); } - Value *allocated = - rewriter - .create(loc, getVoidPtrType(), - rewriter.getSymbolRefAttr(mallocFunc), - cumulativeSize) - .getResult(0); + auto structElementType = lowering.convertType(elementType); auto elementPtrType = structElementType.cast().getPointerTo( type.getMemorySpace()); @@ -878,13 +905,17 @@ struct AllocOpLowering : public LLVMLegalizationPattern { // Field 2: Actual aligned pointer to payload. Value *bitcastAligned = bitcastAllocated; - if (align) { + if (!useAlloca && alignment != 0) { + assert(alignmentValue); // offset = (align - (ptr % align))% align Value *intVal = rewriter.create( loc, this->getIndexType(), allocated); - Value *ptrModAlign = rewriter.create(loc, intVal, align); - Value *subbed = rewriter.create(loc, align, ptrModAlign); - Value *offset = rewriter.create(loc, subbed, align); + Value *ptrModAlign = + rewriter.create(loc, intVal, alignmentValue); + Value *subbed = + rewriter.create(loc, alignmentValue, ptrModAlign); + Value *offset = + rewriter.create(loc, subbed, alignmentValue); Value *aligned = rewriter.create(loc, allocated->getType(), allocated, offset); bitcastAligned = rewriter.create( @@ -930,6 +961,8 @@ struct AllocOpLowering : public LLVMLegalizationPattern { // Return the final value of the descriptor. rewriter.replaceOp(op, {memRefDescriptor}); } + + bool useAlloca; }; // A CallOp automatically promotes MemRefType to a sequence of alloca/store and @@ -1001,9 +1034,17 @@ struct CallIndirectOpLowering : public CallOpInterfaceLowering { struct DeallocOpLowering : public LLVMLegalizationPattern { using LLVMLegalizationPattern::LLVMLegalizationPattern; + DeallocOpLowering(LLVM::LLVMDialect &dialect_, LLVMTypeConverter &converter, + bool useAlloca = false) + : LLVMLegalizationPattern(dialect_, converter), + useAlloca(useAlloca) {} + PatternMatchResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { + if (useAlloca) + return rewriter.eraseOp(op), matchSuccess(); + assert(operands.size() == 1 && "dealloc takes one operand"); OperandAdaptor transformed(operands); @@ -1026,6 +1067,8 @@ struct DeallocOpLowering : public LLVMLegalizationPattern { op, ArrayRef(), rewriter.getSymbolRefAttr(freeFunc), casted); return matchSuccess(); } + + bool useAlloca; }; struct MemRefCastOpLowering : public LLVMLegalizationPattern { @@ -1759,7 +1802,6 @@ void mlir::populateStdToLLVMConversionPatterns( patterns.insert< AddFOpLowering, AddIOpLowering, - AllocOpLowering, AndOpLowering, BranchOpLowering, CallIndirectOpLowering, @@ -1768,7 +1810,6 @@ void mlir::populateStdToLLVMConversionPatterns( CmpIOpLowering, CondBranchOpLowering, ConstLLVMOpLowering, - DeallocOpLowering, DimOpLowering, DivFOpLowering, DivISOpLowering, @@ -1800,6 +1841,10 @@ void mlir::populateStdToLLVMConversionPatterns( ViewOpLowering, XOrOpLowering, ZeroExtendIOpLowering>(*converter.getDialect(), converter); + patterns.insert< + AllocOpLowering, + DeallocOpLowering>( + *converter.getDialect(), converter, clUseAlloca.getValue()); // clang-format on } @@ -1873,6 +1918,7 @@ struct LLVMLoweringPass : public ModulePass { // By default, the patterns are those converting Standard operations to the // LLVMIR dialect. explicit LLVMLoweringPass( + bool useAlloca = false, LLVMPatternListFiller patternListFiller = populateStdToLLVMConversionPatterns, LLVMTypeConverterMaker converterBuilder = makeStandardToLLVMTypeConverter) @@ -1911,17 +1957,25 @@ struct LLVMLoweringPass : public ModulePass { }; } // end namespace -std::unique_ptr> mlir::createLowerToLLVMPass() { - return std::make_unique(); +std::unique_ptr> +mlir::createLowerToLLVMPass(bool useAlloca) { + return std::make_unique(useAlloca); } std::unique_ptr> mlir::createLowerToLLVMPass(LLVMPatternListFiller patternListFiller, - LLVMTypeConverterMaker typeConverterMaker) { - return std::make_unique(patternListFiller, + LLVMTypeConverterMaker typeConverterMaker, + bool useAlloca) { + return std::make_unique(useAlloca, patternListFiller, typeConverterMaker); } static PassRegistration - pass("convert-std-to-llvm", "Convert scalar and vector operations from the " - "Standard to the LLVM dialect"); + pass("convert-std-to-llvm", + "Convert scalar and vector operations from the " + "Standard to the LLVM dialect", + [] { + return std::make_unique( + clUseAlloca.getValue(), populateStdToLLVMConversionPatterns, + makeStandardToLLVMTypeConverter); + });