diff --git a/third_party/mlir/include/mlir/Dialect/GPU/GPUOps.td b/third_party/mlir/include/mlir/Dialect/GPU/GPUOps.td
index 7ef10808888..5f7bab3040f 100644
--- a/third_party/mlir/include/mlir/Dialect/GPU/GPUOps.td
+++ b/third_party/mlir/include/mlir/Dialect/GPU/GPUOps.td
@@ -120,9 +120,9 @@ def GPU_GPUFuncOp : GPU_Op<"func", [FunctionLike, IsolatedFromAbove, Symbol]> {
 
   let builders = [
     OpBuilder<"Builder *builder, OperationState &result, StringRef name, "
-              "FunctionType type, ArrayRef<Type> workgroupAttributions, "
-              "ArrayRef<Type> privateAttributions, "
-              "ArrayRef<NamedAttribute> attrs">
+              "FunctionType type, ArrayRef<Type> workgroupAttributions = {}, "
+              "ArrayRef<Type> privateAttributions = {}, "
+              "ArrayRef<NamedAttribute> attrs = {}">
   ];
 
   let extraClassDeclaration = [{
@@ -138,6 +138,17 @@ def GPU_GPUFuncOp : GPU_Op<"func", [FunctionLike, IsolatedFromAbove, Symbol]> {
       return getTypeAttr().getValue().cast<FunctionType>();
     }
 
+    /// Change the type of this function in place. This is an extremely
+    /// dangerous operation and it is up to the caller to ensure that this is
+    /// legal for this function, and to restore invariants:
+    /// - the entry block args must be updated to match the function params.
+    /// - the argument/result attributes may need an update: if the new type
+    ///   has fewer parameters, we drop the extra attributes; if there are
+    ///   more parameters, they won't have any attributes.
+    // TODO(b/146349912): consider removing this function in favor of rewrite
+    // patterns.
+    void setType(FunctionType newType);
+
     /// Returns the number of buffers located in the workgroup memory.
     unsigned getNumWorkgroupAttributions() {
       return getAttrOfType<IntegerAttr>(getNumWorkgroupAttributionsAttrName())
@@ -270,11 +281,11 @@ def GPU_LaunchFuncOp : GPU_Op<"launch_func">,
 
   let skipDefaultBuilders = 1;
   let builders = [
-    OpBuilder<"Builder *builder, OperationState &result, FuncOp kernelFunc, "
+    OpBuilder<"Builder *builder, OperationState &result, GPUFuncOp kernelFunc, "
               "Value *gridSizeX, Value *gridSizeY, Value *gridSizeZ, "
               "Value *blockSizeX, Value *blockSizeY, Value *blockSizeZ, "
               "ValueRange kernelOperands">,
-    OpBuilder<"Builder *builder, OperationState &result, FuncOp kernelFunc, "
+    OpBuilder<"Builder *builder, OperationState &result, GPUFuncOp kernelFunc, "
               "KernelDim3 gridSize, KernelDim3 blockSize, "
              "ValueRange kernelOperands">
   ];
diff --git a/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.h b/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.h
index 1619a5edd89..f48a1d0b129 100644
--- a/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.h
+++ b/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.h
@@ -83,12 +83,6 @@ StringRef getEntryPointABIAttrName();
 EntryPointABIAttr getEntryPointABIAttr(ArrayRef<int32_t> localSize,
                                        MLIRContext *context);
 
-/// Legalizes a function as an entry function.
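
For context, the defaulted builder arguments make the common case terser: a kernel with no workgroup/private attributions and no extra attributes no longer needs three explicit empty arrays. A minimal sketch, assuming a valid builder and location (the function name and signature below are illustrative, not from the patch):

```c++
#include "mlir/Dialect/GPU/GPUDialect.h"
#include "mlir/IR/Builders.h"

using namespace mlir;

// Build a void(index) gpu.func; attributions and extra attributes now
// default to empty.
static gpu::GPUFuncOp createSimpleKernel(OpBuilder &builder, Location loc) {
  FunctionType type =
      builder.getFunctionType({builder.getIndexType()}, /*results=*/{});
  // Previously the three trailing ArrayRef parameters had to be spelled out
  // explicitly at every call site.
  return builder.create<gpu::GPUFuncOp>(loc, "kernel", type);
}
```
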
-FuncOp lowerAsEntryFunction(FuncOp funcOp, SPIRVTypeConverter &typeConverter,
-                            ConversionPatternRewriter &rewriter,
-                            spirv::EntryPointABIAttr entryPointInfo,
-                            ArrayRef<spirv::InterfaceVarABIAttr> argABIInfo);
-
 /// Sets the InterfaceVarABIAttr and EntryPointABIAttr for a function and its
 /// arguments
 LogicalResult setABIAttrs(FuncOp funcOp,
diff --git a/third_party/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/third_party/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
index e4bdd7cb2be..f41c0c45e96 100644
--- a/third_party/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
+++ b/third_party/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
@@ -489,8 +489,6 @@ struct GPUFuncOpLowering : LLVMOpLowering {
     }
 
     // Rewrite the original GPU function to an LLVM function.
-    // TODO(zinenko): there is a hack in the std->llvm lowering that promotes
-    // structs to pointers that probably needs to be replicated here.
     auto funcType = lowering.convertType(gpuFuncOp.getType())
                         .cast<LLVM::LLVMType>()
                         .getPointerElementTy();
@@ -576,16 +574,51 @@ struct GPUFuncOpLowering : LLVMOpLowering {
       }
     }
 
+    // Move the region to the new function and update the entry block
+    // signature.
     rewriter.inlineRegionBefore(gpuFuncOp.getBody(), llvmFuncOp.getBody(),
                                 llvmFuncOp.end());
     rewriter.applySignatureConversion(&llvmFuncOp.getBody(),
                                       signatureConversion);
 
+    {
+      // For memref-typed arguments, insert the relevant loads at the beginning
+      // of the block to comply with the LLVM dialect calling convention. This
+      // needs to be done after signature conversion to get the right types.
+      OpBuilder::InsertionGuard guard(rewriter);
+      Block &block = llvmFuncOp.front();
+      rewriter.setInsertionPointToStart(&block);
+
+      for (auto en : llvm::enumerate(gpuFuncOp.getType().getInputs())) {
+        if (!en.value().isa<MemRefType>() &&
+            !en.value().isa<UnrankedMemRefType>())
+          continue;
+
+        BlockArgument *arg = block.getArgument(en.index());
+        Value *loaded = rewriter.create<LLVM::LoadOp>(loc, arg);
+        rewriter.replaceUsesOfBlockArgument(arg, loaded);
+      }
+    }
+
     rewriter.eraseOp(gpuFuncOp);
     return matchSuccess();
   }
 };
 
+struct GPUReturnOpLowering : public LLVMOpLowering {
+  GPUReturnOpLowering(LLVMTypeConverter &typeConverter)
+      : LLVMOpLowering(gpu::ReturnOp::getOperationName(),
+                       typeConverter.getDialect()->getContext(),
+                       typeConverter) {}
+
+  PatternMatchResult
+  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, operands,
+                                                ArrayRef<Block *>());
+    return matchSuccess();
+  }
+};
+
 /// Import the GPU Ops to NVVM Patterns.
 #include "GPUToNVVM.cpp.inc"
 
@@ -632,7 +665,8 @@ void mlir::populateGpuToNVVMConversionPatterns(
                                   NVVM::BlockIdYOp, NVVM::BlockIdZOp>,
       GPUIndexIntrinsicOpLowering<gpu::GridDimOp, NVVM::GridDimXOp,
                                   NVVM::GridDimYOp, NVVM::GridDimZOp>,
-      GPUAllReduceOpLowering, GPUFuncOpLowering>(converter);
+      GPUAllReduceOpLowering, GPUFuncOpLowering, GPUReturnOpLowering>(
+      converter);
   patterns.insert<OpToFuncCallLowering<ExpOp>>(converter, "__nv_expf",
                                                "__nv_exp");
 }
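
The loads inserted above reflect the LLVM dialect calling convention used by the std-to-LLVM lowering: memref-typed arguments arrive as pointers to their descriptor structs, so the lowered kernel loads the descriptor once at function entry. Below is a condensed sketch of what `LowerGpuOpsToNVVMOpsPass` in this file already wires together; the free-standing function name and the exact target setup are illustrative, not part of the patch:

```c++
#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
#include "mlir/Dialect/GPU/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;

// Register the patterns populated above (now including GPUFuncOpLowering and
// GPUReturnOpLowering) and run a partial conversion over the module.
static LogicalResult runGpuToNVVM(ModuleOp m) {
  LLVMTypeConverter converter(m.getContext());
  OwningRewritePatternList patterns;
  populateGpuToNVVMConversionPatterns(converter, patterns);

  ConversionTarget target(*m.getContext());
  target.addIllegalDialect<gpu::GPUDialect>();
  target.addLegalDialect<LLVM::LLVMDialect, NVVM::NVVMDialect>();
  return applyPartialConversion(m, target, patterns, &converter);
}
```
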
diff --git a/third_party/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.cpp b/third_party/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.cpp
index 2b39c0db994..a8747a7a9bf 100644
--- a/third_party/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.cpp
+++ b/third_party/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.cpp
@@ -51,21 +51,20 @@ public:
                   ConversionPatternRewriter &rewriter) const override;
 };
 
-/// Pattern to convert a kernel function in GPU dialect (a FuncOp with the
-/// attribute gpu.kernel) within a spv.module.
-class KernelFnConversion final : public SPIRVOpLowering<FuncOp> {
+/// Pattern to convert a kernel function in GPU dialect within a spv.module.
+class KernelFnConversion final : public SPIRVOpLowering<gpu::GPUFuncOp> {
 public:
   KernelFnConversion(MLIRContext *context, SPIRVTypeConverter &converter,
                      ArrayRef<int64_t> workGroupSize,
                      PatternBenefit benefit = 1)
-      : SPIRVOpLowering<FuncOp>(context, converter, benefit) {
+      : SPIRVOpLowering<gpu::GPUFuncOp>(context, converter, benefit) {
     auto config = workGroupSize.take_front(3);
     workGroupSizeAsInt32.assign(config.begin(), config.end());
     workGroupSizeAsInt32.resize(3, 1);
   }
 
   PatternMatchResult
-  matchAndRewrite(FuncOp funcOp, ArrayRef<Value *> operands,
+  matchAndRewrite(gpu::GPUFuncOp funcOp, ArrayRef<Value *> operands,
                   ConversionPatternRewriter &rewriter) const override;
 
 private:
@@ -96,6 +95,17 @@ public:
                   ConversionPatternRewriter &rewriter) const override;
 };
 
+/// Pattern to convert a gpu.return into a SPIR-V return.
+// TODO: This can go to DRR when GPU return has operands.
+class GPUReturnOpConversion final : public SPIRVOpLowering<gpu::ReturnOp> {
+public:
+  using SPIRVOpLowering<gpu::ReturnOp>::SPIRVOpLowering;
+
+  PatternMatchResult
+  matchAndRewrite(gpu::ReturnOp returnOp, ArrayRef<Value *> operands,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+
 } // namespace
 
 //===----------------------------------------------------------------------===//
@@ -204,11 +214,58 @@ PatternMatchResult LaunchConfigConversion::matchAndRewrite(
 }
 
 //===----------------------------------------------------------------------===//
-// FuncOp with gpu.kernel attribute.
+// GPUFuncOp
 //===----------------------------------------------------------------------===//
 
+// Legalizes a GPU function as an entry SPIR-V function.
+static FuncOp
+lowerAsEntryFunction(gpu::GPUFuncOp funcOp, SPIRVTypeConverter &typeConverter,
+                     ConversionPatternRewriter &rewriter,
+                     spirv::EntryPointABIAttr entryPointInfo,
+                     ArrayRef<spirv::InterfaceVarABIAttr> argABIInfo) {
+  auto fnType = funcOp.getType();
+  if (fnType.getNumResults()) {
+    funcOp.emitError("SPIR-V lowering only supports entry functions "
+                     "with no return values right now");
+    return nullptr;
+  }
+  if (fnType.getNumInputs() != argABIInfo.size()) {
+    funcOp.emitError(
+        "lowering as entry functions requires ABI info for all arguments");
+    return nullptr;
+  }
+  // For entry functions, we need to make the signature void(void). Compute the
+  // replacement value for all arguments and replace all uses.
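
Note that `lowerAsEntryFunction` becomes a file-local helper operating on `gpu::GPUFuncOp`, while keeping its previous contract: no results, and exactly one `InterfaceVarABIAttr` per argument. A hedged sketch of how such per-argument ABI info is typically assembled before calling it, mirroring what `KernelFnConversion::matchAndRewrite` builds; the helper name and the descriptor-set-0/consecutive-bindings choice are illustrative assumptions:

```c++
#include "mlir/Dialect/SPIRV/SPIRVLowering.h"
#include "mlir/Dialect/SPIRV/SPIRVOps.h"
#include "llvm/ADT/SmallVector.h"

using namespace mlir;

// One InterfaceVarABIAttr per kernel argument: descriptor set 0, binding i,
// storage buffer storage class.
static llvm::SmallVector<spirv::InterfaceVarABIAttr, 4>
makeDefaultArgABI(unsigned numArgs, MLIRContext *ctx) {
  llvm::SmallVector<spirv::InterfaceVarABIAttr, 4> argABI;
  for (unsigned i = 0; i < numArgs; ++i)
    argABI.push_back(spirv::getInterfaceVarABIAttr(
        /*descriptorSet=*/0, /*binding=*/i,
        spirv::StorageClass::StorageBuffer, ctx));
  return argABI;
}
```
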
+  TypeConverter::SignatureConversion signatureConverter(fnType.getNumInputs());
+  {
+    for (auto argType : enumerate(funcOp.getType().getInputs())) {
+      auto convertedType = typeConverter.convertType(argType.value());
+      signatureConverter.addInputs(argType.index(), convertedType);
+    }
+  }
+  auto newFuncOp = rewriter.create<FuncOp>(
+      funcOp.getLoc(), funcOp.getName(),
+      rewriter.getFunctionType(signatureConverter.getConvertedTypes(),
+                               llvm::None),
+      ArrayRef<NamedAttribute>());
+  for (const auto &namedAttr : funcOp.getAttrs()) {
+    if (namedAttr.first.is(impl::getTypeAttrName()) ||
+        namedAttr.first.is(SymbolTable::getSymbolAttrName()))
+      continue;
+    newFuncOp.setAttr(namedAttr.first, namedAttr.second);
+  }
+  rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
+                              newFuncOp.end());
+  rewriter.applySignatureConversion(&newFuncOp.getBody(), signatureConverter);
+  rewriter.eraseOp(funcOp);
+
+  spirv::setABIAttrs(newFuncOp, entryPointInfo, argABIInfo);
+  return newFuncOp;
+}
+
 PatternMatchResult
-KernelFnConversion::matchAndRewrite(FuncOp funcOp, ArrayRef<Value *> operands,
+KernelFnConversion::matchAndRewrite(gpu::GPUFuncOp funcOp,
+                                    ArrayRef<Value *> operands,
                                     ConversionPatternRewriter &rewriter) const {
   if (!gpu::GPUDialect::isKernel(funcOp)) {
     return matchFailure();
@@ -223,8 +280,8 @@ KernelFnConversion::matchAndRewrite(FuncOp funcOp, ArrayRef<Value *> operands,
   auto context = rewriter.getContext();
   auto entryPointAttr =
       spirv::getEntryPointABIAttr(workGroupSizeAsInt32, context);
-  FuncOp newFuncOp = spirv::lowerAsEntryFunction(
-      funcOp, typeConverter, rewriter, entryPointAttr, argABI);
+  FuncOp newFuncOp = lowerAsEntryFunction(funcOp, typeConverter, rewriter,
+                                          entryPointAttr, argABI);
   if (!newFuncOp) {
     return matchFailure();
   }
@@ -274,6 +331,20 @@ PatternMatchResult KernelModuleTerminatorConversion::matchAndRewrite(
   return matchSuccess();
 }
 
+//===----------------------------------------------------------------------===//
+// GPU return inside kernel functions to SPIR-V return.
+//===----------------------------------------------------------------------===//
+
+PatternMatchResult GPUReturnOpConversion::matchAndRewrite(
+    gpu::ReturnOp returnOp, ArrayRef<Value *> operands,
+    ConversionPatternRewriter &rewriter) const {
+  if (!operands.empty())
+    return matchFailure();
+
+  rewriter.replaceOpWithNewOp<spirv::ReturnOp>(returnOp);
+  return matchSuccess();
+}
+
 //===----------------------------------------------------------------------===//
 // GPU To SPIRV Patterns.
 //===----------------------------------------------------------------------===//
@@ -285,7 +356,8 @@ void populateGPUToSPIRVPatterns(MLIRContext *context,
                                 ArrayRef<int64_t> workGroupSize) {
   patterns.insert<KernelFnConversion>(context, typeConverter, workGroupSize);
   patterns.insert<
-      ForOpConversion, KernelModuleConversion, KernelModuleTerminatorConversion,
+      GPUReturnOpConversion, ForOpConversion, KernelModuleConversion,
+      KernelModuleTerminatorConversion,
       LaunchConfigConversion<gpu::BlockDimOp, spirv::BuiltIn::WorkgroupSize>,
       LaunchConfigConversion<gpu::BlockIdOp, spirv::BuiltIn::WorkgroupId>,
       LaunchConfigConversion<gpu::GridDimOp, spirv::BuiltIn::NumWorkgroups>,
diff --git a/third_party/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/third_party/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index 1f48d6d47e4..46a568caac5 100644
--- a/third_party/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/third_party/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -94,9 +94,9 @@ LogicalResult GPUDialect::verifyOperationAttribute(Operation *op,
     // Check that `launch_func` refers to a well-formed kernel function.
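
The `setType` hook introduced in this patch (see `GPUFuncOp::setType` below) deliberately leaves the entry block alone; per its documentation, the caller must keep block arguments in sync. A hypothetical caller honoring that contract — the helper below is illustrative, not part of the patch:

```c++
#include "mlir/Dialect/GPU/GPUDialect.h"
#include <cassert>

using namespace mlir;

// Drop the last argument of a gpu.func. The caller erases the matching entry
// block argument itself; setType only drops out-of-range argument attributes.
static void dropLastArgument(gpu::GPUFuncOp func) {
  FunctionType oldType = func.getType();
  assert(oldType.getNumInputs() != 0 && "no argument to drop");

  auto newType = FunctionType::get(oldType.getInputs().drop_back(),
                                   oldType.getResults(), func.getContext());
  // Restore the invariant: entry block arguments must match function params.
  func.getBody().front().eraseArgument(oldType.getNumInputs() - 1);
  func.setType(newType);
}
```
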
     StringRef kernelName = launchOp.kernel();
     Operation *kernelFunc = kernelModule.lookupSymbol(kernelName);
-    auto kernelStdFunction = dyn_cast_or_null<::mlir::FuncOp>(kernelFunc);
+    auto kernelGPUFunction = dyn_cast_or_null<GPUFuncOp>(kernelFunc);
     auto kernelLLVMFunction = dyn_cast_or_null<LLVM::LLVMFuncOp>(kernelFunc);
-    if (!kernelStdFunction && !kernelLLVMFunction)
+    if (!kernelGPUFunction && !kernelLLVMFunction)
       return launchOp.emitOpError("kernel function '")
              << kernelName << "' is undefined";
     if (!kernelFunc->getAttrOfType<UnitAttr>(
@@ -107,7 +107,7 @@ LogicalResult GPUDialect::verifyOperationAttribute(Operation *op,
     unsigned actualNumArguments = launchOp.getNumKernelOperands();
     unsigned expectedNumArguments = kernelLLVMFunction
                                         ? kernelLLVMFunction.getNumArguments()
-                                        : kernelStdFunction.getNumArguments();
+                                        : kernelGPUFunction.getNumArguments();
     if (expectedNumArguments != actualNumArguments)
       return launchOp.emitOpError("got ")
              << actualNumArguments << " kernel operands but expected "
@@ -488,7 +488,7 @@ void LaunchOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
 //===----------------------------------------------------------------------===//
 
 void LaunchFuncOp::build(Builder *builder, OperationState &result,
-                         ::mlir::FuncOp kernelFunc, Value *gridSizeX,
+                         GPUFuncOp kernelFunc, Value *gridSizeX,
                          Value *gridSizeY, Value *gridSizeZ, Value *blockSizeX,
                          Value *blockSizeY, Value *blockSizeZ,
                          ValueRange kernelOperands) {
@@ -505,7 +505,7 @@ void LaunchFuncOp::build(Builder *builder, OperationState &result,
 }
 
 void LaunchFuncOp::build(Builder *builder, OperationState &result,
-                         ::mlir::FuncOp kernelFunc, KernelDim3 gridSize,
+                         GPUFuncOp kernelFunc, KernelDim3 gridSize,
                          KernelDim3 blockSize, ValueRange kernelOperands) {
   build(builder, result, kernelFunc, gridSize.x, gridSize.y, gridSize.z,
         blockSize.x, blockSize.y, blockSize.z, kernelOperands);
@@ -718,6 +718,18 @@ void printGPUFuncOp(OpAsmPrinter &p, GPUFuncOp op) {
   p.printRegion(op.getBody(), /*printEntryBlockArgs=*/false);
 }
 
+void GPUFuncOp::setType(FunctionType newType) {
+  auto oldType = getType();
+  assert(newType.getNumResults() == oldType.getNumResults() &&
+         "unimplemented: changes to the number of results");
+
+  SmallVector<char, 16> nameBuf;
+  for (int i = newType.getNumInputs(), e = oldType.getNumInputs(); i < e; i++)
+    removeAttr(getArgAttrName(i, nameBuf));
+
+  setAttr(getTypeAttrName(), TypeAttr::get(newType));
+}
+
 /// Hook for FunctionLike verifier.
 LogicalResult GPUFuncOp::verifyType() {
   Type type = getTypeAttr().getValue();
diff --git a/third_party/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp b/third_party/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp
index b466cc280df..416a37b3270 100644
--- a/third_party/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp
+++ b/third_party/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp
@@ -39,19 +39,21 @@ static void createForAllDimensions(OpBuilder &builder, Location loc,
   }
 }
 
-// Add operations generating block/thread ids and gird/block dimensions at the
-// beginning of `kernelFunc` and replace uses of the respective function args.
+// Add operations generating block/thread ids and grid/block dimensions at the
+// beginning of the `body` region and replace uses of the respective function
+// arguments.
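
With outlining now producing `gpu.func` directly (see the KernelOutlining changes below), `gpu.return` survives in the kernel until the NVVM/SPIR-V conversions rewrite it, and the verifier change above accepts `gpu.func` kernels referenced by `gpu.launch_func`. A minimal driver sketch for exercising the pass, assuming the usual `createGpuKernelOutliningPass()` factory from `mlir/Dialect/GPU/Passes.h` (the wrapper function name is illustrative):

```c++
#include "mlir/Dialect/GPU/Passes.h"
#include "mlir/Pass/PassManager.h"

using namespace mlir;

// Run kernel outlining over a module: every gpu.launch body becomes a
// gpu.func kernel in a nested module, launched through gpu.launch_func.
static LogicalResult outlineAllKernels(ModuleOp module) {
  PassManager pm(module.getContext());
  pm.addPass(createGpuKernelOutliningPass());
  return pm.run(module);
}
```
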
-static void injectGpuIndexOperations(Location loc, FuncOp kernelFunc) {
-  OpBuilder OpBuilder(kernelFunc.getBody());
+static void injectGpuIndexOperations(Location loc, Region &body) {
+  OpBuilder builder(loc->getContext());
+  Block &firstBlock = body.front();
+  builder.setInsertionPointToStart(&firstBlock);
   SmallVector<Value *, 12> indexOps;
-  createForAllDimensions<gpu::BlockIdOp>(OpBuilder, loc, indexOps);
-  createForAllDimensions<gpu::ThreadIdOp>(OpBuilder, loc, indexOps);
-  createForAllDimensions<gpu::GridDimOp>(OpBuilder, loc, indexOps);
-  createForAllDimensions<gpu::BlockDimOp>(OpBuilder, loc, indexOps);
+  createForAllDimensions<gpu::BlockIdOp>(builder, loc, indexOps);
+  createForAllDimensions<gpu::ThreadIdOp>(builder, loc, indexOps);
+  createForAllDimensions<gpu::GridDimOp>(builder, loc, indexOps);
+  createForAllDimensions<gpu::BlockDimOp>(builder, loc, indexOps);
   // Replace the leading 12 function args with the respective thread/block index
   // operations. Iterate backwards since args are erased and indices change.
   for (int i = 11; i >= 0; --i) {
-    auto &firstBlock = kernelFunc.front();
     firstBlock.getArgument(i)->replaceAllUsesWith(indexOps[i]);
     firstBlock.eraseArgument(i);
   }
@@ -63,7 +65,7 @@ static bool isInliningBeneficiary(Operation *op) {
 
 // Move arguments of the given kernel function into the function if this reduces
 // the number of kernel arguments.
-static gpu::LaunchFuncOp inlineBeneficiaryOps(FuncOp kernelFunc,
+static gpu::LaunchFuncOp inlineBeneficiaryOps(gpu::GPUFuncOp kernelFunc,
                                               gpu::LaunchFuncOp launch) {
   OpBuilder kernelBuilder(kernelFunc.getBody());
   auto &firstBlock = kernelFunc.getBody().front();
@@ -107,31 +109,30 @@ static gpu::LaunchFuncOp inlineBeneficiaryOps(FuncOp kernelFunc,
 
-// Outline the `gpu.launch` operation body into a kernel function. Replace
-// `gpu.return` operations by `std.return` in the generated function.
-static FuncOp outlineKernelFunc(gpu::LaunchOp launchOp) {
+// Outline the `gpu.launch` operation body into a kernel function. The
+// generated kernel keeps `gpu.return` as its terminator.
+static gpu::GPUFuncOp outlineKernelFunc(gpu::LaunchOp launchOp) {
   Location loc = launchOp.getLoc();
+  // Create a builder with no insertion point; insertion will happen separately
+  // due to symbol table manipulation.
+  OpBuilder builder(launchOp.getContext());
+
   SmallVector<Type, 4> kernelOperandTypes(launchOp.getKernelOperandTypes());
   FunctionType type =
       FunctionType::get(kernelOperandTypes, {}, launchOp.getContext());
   std::string kernelFuncName =
       Twine(launchOp.getParentOfType<FuncOp>().getName(), "_kernel").str();
-  FuncOp outlinedFunc = FuncOp::create(loc, kernelFuncName, type);
-  outlinedFunc.getBody().takeBody(launchOp.body());
-  Builder builder(launchOp.getContext());
+  auto outlinedFunc = builder.create<gpu::GPUFuncOp>(loc, kernelFuncName, type);
   outlinedFunc.setAttr(gpu::GPUDialect::getKernelFuncAttrName(),
                        builder.getUnitAttr());
-  injectGpuIndexOperations(loc, outlinedFunc);
-  outlinedFunc.walk([](gpu::ReturnOp op) {
-    OpBuilder replacer(op);
-    replacer.create<ReturnOp>(op.getLoc());
-    op.erase();
-  });
+  outlinedFunc.body().takeBody(launchOp.body());
+  injectGpuIndexOperations(loc, outlinedFunc.body());
   return outlinedFunc;
 }
 
 // Replace `gpu.launch` operations with a `gpu.launch_func` operation launching
 // `kernelFunc`. The kernel func contains the body of the `gpu.launch` with
 // constant region arguments inlined.
-static void convertToLaunchFuncOp(gpu::LaunchOp &launchOp, FuncOp kernelFunc) {
+static void convertToLaunchFuncOp(gpu::LaunchOp &launchOp,
+                                  gpu::GPUFuncOp kernelFunc) {
   OpBuilder builder(launchOp);
   auto launchFuncOp = builder.create<gpu::LaunchFuncOp>(
       launchOp.getLoc(), kernelFunc, launchOp.getGridSizeOperandValues(),
@@ -160,7 +161,7 @@ public:
     // Insert just after the function.
     Block::iterator insertPt(func.getOperation()->getNextNode());
     func.walk([&](gpu::LaunchOp op) {
-      FuncOp outlinedFunc = outlineKernelFunc(op);
+      gpu::GPUFuncOp outlinedFunc = outlineKernelFunc(op);
 
       // Create nested module and insert outlinedFunc. The module will
       // originally get the same name as the function, but may be renamed on
@@ -183,7 +184,7 @@ public:
 
 private:
   // Returns a module containing kernelFunc and all callees (recursive).
-  ModuleOp createKernelModule(FuncOp kernelFunc,
+  ModuleOp createKernelModule(gpu::GPUFuncOp kernelFunc,
                               const SymbolTable &parentSymbolTable) {
     auto context = getModule().getContext();
     Builder builder(context);
diff --git a/third_party/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp b/third_party/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp
index 67c036dbcf9..1e68b49c48b 100644
--- a/third_party/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp
+++ b/third_party/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp
@@ -249,43 +249,6 @@ Value *mlir::spirv::getBuiltinVariableValue(Operation *op,
 // Entry Function signature Conversion
 //===----------------------------------------------------------------------===//
 
-FuncOp mlir::spirv::lowerAsEntryFunction(
-    FuncOp funcOp, SPIRVTypeConverter &typeConverter,
-    ConversionPatternRewriter &rewriter,
-    spirv::EntryPointABIAttr entryPointInfo,
-    ArrayRef<spirv::InterfaceVarABIAttr> argABIInfo) {
-  auto fnType = funcOp.getType();
-  if (fnType.getNumResults()) {
-    funcOp.emitError("SPIR-V lowering only supports entry functions"
-                     "with no return values right now");
-    return nullptr;
-  }
-  if (fnType.getNumInputs() != argABIInfo.size()) {
-    funcOp.emitError(
-        "lowering as entry functions requires ABI info for all arguments");
-    return nullptr;
-  }
-  // For entry functions need to make the signature void(void). Compute the
-  // replacement value for all arguments and replace all uses.
-  TypeConverter::SignatureConversion signatureConverter(fnType.getNumInputs());
-  {
-    for (auto argType : enumerate(funcOp.getType().getInputs())) {
-      auto convertedType = typeConverter.convertType(argType.value());
-      signatureConverter.addInputs(argType.index(), convertedType);
-    }
-  }
-  auto newFuncOp = rewriter.cloneWithoutRegions(funcOp);
-  rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
-                              newFuncOp.end());
-  newFuncOp.setType(rewriter.getFunctionType(
-      signatureConverter.getConvertedTypes(), llvm::None));
-  rewriter.applySignatureConversion(&newFuncOp.getBody(), signatureConverter);
-  rewriter.eraseOp(funcOp);
-
-  spirv::setABIAttrs(newFuncOp, entryPointInfo, argABIInfo);
-  return newFuncOp;
-}
-
 LogicalResult
 mlir::spirv::setABIAttrs(FuncOp funcOp,
                          spirv::EntryPointABIAttr entryPointInfo,
                          ArrayRef<spirv::InterfaceVarABIAttr> argABIInfo) {