Add lowering for module with gpu.kernel_module attribute.
The existing GPU to SPIR-V lowering created a spv.module for every function with the gpu.kernel attribute. A better approach is to lower the module that the function lives in (which has the gpu.kernel_module attribute) to a spv.module operation. This better captures the host-device separation modeled by the GPU dialect and simplifies the lowering as well.
PiperOrigin-RevId: 284574688
Change-Id: Ibd37354595e0171693ebe3e083d517c21fff8192
This commit is contained in:
parent
c164c155c4
commit
6bf9041b35
1
third_party/mlir/BUILD
vendored
1
third_party/mlir/BUILD
vendored
@ -760,6 +760,7 @@ cc_library(
|
|||||||
includes = ["include"],
|
includes = ["include"],
|
||||||
deps = [
|
deps = [
|
||||||
":GPUDialect",
|
":GPUDialect",
|
||||||
|
":IR",
|
||||||
":LoopOps",
|
":LoopOps",
|
||||||
":Pass",
|
":Pass",
|
||||||
":SPIRVDialect",
|
":SPIRVDialect",
|
||||||
|
@ -329,13 +329,17 @@ def SPV_ModuleOp : SPV_Op<"module",
|
|||||||
|
|
||||||
let regions = (region SizedRegion<1>:$body);
|
let regions = (region SizedRegion<1>:$body);
|
||||||
|
|
||||||
let builders = [OpBuilder<"Builder *, OperationState &state">,
|
let builders =
|
||||||
OpBuilder<[{Builder *, OperationState &state,
|
[OpBuilder<"Builder *, OperationState &state">,
|
||||||
IntegerAttr addressing_model,
|
OpBuilder<[{Builder *, OperationState &state,
|
||||||
IntegerAttr memory_model,
|
IntegerAttr addressing_model,
|
||||||
/*optional*/ArrayAttr capabilities = nullptr,
|
IntegerAttr memory_model}]>,
|
||||||
/*optional*/ArrayAttr extensions = nullptr,
|
OpBuilder<[{Builder *, OperationState &state,
|
||||||
/*optional*/ArrayAttr extended_instruction_sets = nullptr}]>];
|
spirv::AddressingModel addressing_model,
|
||||||
|
spirv::MemoryModel memory_model,
|
||||||
|
/*optional*/ ArrayRef<spirv::Capability> capabilities = {},
|
||||||
|
/*optional*/ ArrayRef<spirv::Extension> extensions = {},
|
||||||
|
/*optional*/ ArrayAttr extended_instruction_sets = nullptr}]>];
|
||||||
|
|
||||||
// We need to ensure the block inside the region is properly terminated;
|
// We need to ensure the block inside the region is properly terminated;
|
||||||
// the auto-generated builders do not guarantee that.
|
// the auto-generated builders do not guarantee that.
|
||||||
|
@ -23,6 +23,7 @@
|
|||||||
#include "mlir/Dialect/SPIRV/SPIRVDialect.h"
|
#include "mlir/Dialect/SPIRV/SPIRVDialect.h"
|
||||||
#include "mlir/Dialect/SPIRV/SPIRVLowering.h"
|
#include "mlir/Dialect/SPIRV/SPIRVLowering.h"
|
||||||
#include "mlir/Dialect/SPIRV/SPIRVOps.h"
|
#include "mlir/Dialect/SPIRV/SPIRVOps.h"
|
||||||
|
#include "mlir/IR/Module.h"
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
|
|
||||||
@ -71,8 +72,36 @@ private:
|
|||||||
SmallVector<int32_t, 3> workGroupSizeAsInt32;
|
SmallVector<int32_t, 3> workGroupSizeAsInt32;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
/// Pattern to convert a module with gpu.kernel_module attribute to a
|
||||||
|
/// spv.module.
|
||||||
|
class KernelModuleConversion final : public SPIRVOpLowering<ModuleOp> {
|
||||||
|
public:
|
||||||
|
using SPIRVOpLowering<ModuleOp>::SPIRVOpLowering;
|
||||||
|
|
||||||
|
PatternMatchResult
|
||||||
|
matchAndRewrite(ModuleOp moduleOp, ArrayRef<Value *> operands,
|
||||||
|
ConversionPatternRewriter &rewriter) const override;
|
||||||
|
};
|
||||||
|
|
||||||
|
/// Pattern to convert a module terminator op to a terminator of spv.module op.
|
||||||
|
// TODO: Move this into DRR, but that requires ModuleTerminatorOp to be defined
|
||||||
|
// in ODS.
|
||||||
|
class KernelModuleTerminatorConversion final
|
||||||
|
: public SPIRVOpLowering<ModuleTerminatorOp> {
|
||||||
|
public:
|
||||||
|
using SPIRVOpLowering<ModuleTerminatorOp>::SPIRVOpLowering;
|
||||||
|
|
||||||
|
PatternMatchResult
|
||||||
|
matchAndRewrite(ModuleTerminatorOp terminatorOp, ArrayRef<Value *> operands,
|
||||||
|
ConversionPatternRewriter &rewriter) const override;
|
||||||
|
};
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// loop::ForOp.
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
PatternMatchResult
|
PatternMatchResult
|
||||||
ForOpConversion::matchAndRewrite(loop::ForOp forOp, ArrayRef<Value *> operands,
|
ForOpConversion::matchAndRewrite(loop::ForOp forOp, ArrayRef<Value *> operands,
|
||||||
ConversionPatternRewriter &rewriter) const {
|
ConversionPatternRewriter &rewriter) const {
|
||||||
@ -142,6 +171,10 @@ ForOpConversion::matchAndRewrite(loop::ForOp forOp, ArrayRef<Value *> operands,
|
|||||||
return matchSuccess();
|
return matchSuccess();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// Builtins.
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
template <typename SourceOp, spirv::BuiltIn builtin>
|
template <typename SourceOp, spirv::BuiltIn builtin>
|
||||||
PatternMatchResult LaunchConfigConversion<SourceOp, builtin>::matchAndRewrite(
|
PatternMatchResult LaunchConfigConversion<SourceOp, builtin>::matchAndRewrite(
|
||||||
SourceOp op, ArrayRef<Value *> operands,
|
SourceOp op, ArrayRef<Value *> operands,
|
||||||
@ -170,6 +203,10 @@ PatternMatchResult LaunchConfigConversion<SourceOp, builtin>::matchAndRewrite(
|
|||||||
return this->matchSuccess();
|
return this->matchSuccess();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// FuncOp with gpu.kernel attribute.
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
PatternMatchResult
|
PatternMatchResult
|
||||||
KernelFnConversion::matchAndRewrite(FuncOp funcOp, ArrayRef<Value *> operands,
|
KernelFnConversion::matchAndRewrite(FuncOp funcOp, ArrayRef<Value *> operands,
|
||||||
ConversionPatternRewriter &rewriter) const {
|
ConversionPatternRewriter &rewriter) const {
|
||||||
@ -196,6 +233,51 @@ KernelFnConversion::matchAndRewrite(FuncOp funcOp, ArrayRef<Value *> operands,
|
|||||||
return matchSuccess();
|
return matchSuccess();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// ModuleOp with gpu.kernel_module.
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
/// Rewrites a module tagged with the gpu.kernel_module attribute into a
/// spirv::ModuleOp that takes over the module's body region.
///
/// Fails to match when `moduleOp` does not carry the kernel-module unit
/// attribute. `operands` is not used.
PatternMatchResult KernelModuleConversion::matchAndRewrite(
    ModuleOp moduleOp, ArrayRef<Value *> operands,
    ConversionPatternRewriter &rewriter) const {
  // Guard: only modules marked as kernel modules are handled here.
  auto kernelModuleAttr = moduleOp.getAttrOfType<UnitAttr>(
      gpu::GPUDialect::getKernelModuleAttrName());
  if (!kernelModuleAttr)
    return matchFailure();

  // TODO : Generalize this to account for different extensions,
  // capabilities, extended_instruction_sets, other addressing models
  // and memory models.
  auto loc = moduleOp.getLoc();
  auto spvModule = rewriter.create<spirv::ModuleOp>(
      loc, spirv::AddressingModel::Logical, spirv::MemoryModel::GLSL450,
      spirv::Capability::Shader,
      spirv::Extension::SPV_KHR_storage_buffer_storage_class);

  // Splice the body of the source module in at the front of the SPIR-V
  // module's region.
  Region &targetRegion = spvModule.getOperation()->getRegion(0);
  rewriter.inlineRegionBefore(moduleOp.getBodyRegion(), targetRegion,
                              targetRegion.begin());

  // The spv.module builder already put a block with a terminator into the
  // region; after the splice it is the last block, so drop it. The terminator
  // of the original module, now in the remaining block, is legalized later.
  targetRegion.back().erase();

  rewriter.eraseOp(moduleOp);
  return matchSuccess();
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// ModuleTerminatorOp for gpu.kernel_module.
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
/// Replaces `terminatorOp` with a spirv::ModuleEndOp and reports a successful
/// match. `operands` is not used.
PatternMatchResult KernelModuleTerminatorConversion::matchAndRewrite(
    ModuleTerminatorOp terminatorOp, ArrayRef<Value *> operands,
    ConversionPatternRewriter &rewriter) const {
  // The new op is created without operands or attributes.
  rewriter.replaceOpWithNewOp<spirv::ModuleEndOp>(terminatorOp);

  return matchSuccess();
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// GPU To SPIRV Patterns.
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
namespace mlir {
|
namespace mlir {
|
||||||
void populateGPUToSPIRVPatterns(MLIRContext *context,
|
void populateGPUToSPIRVPatterns(MLIRContext *context,
|
||||||
SPIRVTypeConverter &typeConverter,
|
SPIRVTypeConverter &typeConverter,
|
||||||
@ -203,7 +285,7 @@ void populateGPUToSPIRVPatterns(MLIRContext *context,
|
|||||||
ArrayRef<int64_t> workGroupSize) {
|
ArrayRef<int64_t> workGroupSize) {
|
||||||
patterns.insert<KernelFnConversion>(context, typeConverter, workGroupSize);
|
patterns.insert<KernelFnConversion>(context, typeConverter, workGroupSize);
|
||||||
patterns.insert<
|
patterns.insert<
|
||||||
ForOpConversion,
|
ForOpConversion, KernelModuleConversion, KernelModuleTerminatorConversion,
|
||||||
LaunchConfigConversion<gpu::BlockDimOp, spirv::BuiltIn::WorkgroupSize>,
|
LaunchConfigConversion<gpu::BlockDimOp, spirv::BuiltIn::WorkgroupSize>,
|
||||||
LaunchConfigConversion<gpu::BlockIdOp, spirv::BuiltIn::WorkgroupId>,
|
LaunchConfigConversion<gpu::BlockIdOp, spirv::BuiltIn::WorkgroupId>,
|
||||||
LaunchConfigConversion<gpu::GridDimOp, spirv::BuiltIn::NumWorkgroups>,
|
LaunchConfigConversion<gpu::GridDimOp, spirv::BuiltIn::NumWorkgroups>,
|
||||||
|
@ -67,34 +67,19 @@ void GPUToSPIRVPass::runOnModule() {
|
|||||||
auto context = &getContext();
|
auto context = &getContext();
|
||||||
auto module = getModule();
|
auto module = getModule();
|
||||||
|
|
||||||
SmallVector<Operation *, 4> spirvModules;
|
SmallVector<Operation *, 1> kernelModules;
|
||||||
module.walk([&module, &spirvModules](FuncOp funcOp) {
|
OpBuilder builder(context);
|
||||||
if (!gpu::GPUDialect::isKernel(funcOp)) {
|
module.walk([&builder, &kernelModules](ModuleOp moduleOp) {
|
||||||
return;
|
if (moduleOp.getAttrOfType<UnitAttr>(
|
||||||
|
gpu::GPUDialect::getKernelModuleAttrName())) {
|
||||||
|
// For each kernel module (should be only 1 for now, but that is not a
|
||||||
|
// requirement here), clone the module for conversion because the
|
||||||
|
// gpu.launch function still needs the kernel module.
|
||||||
|
builder.setInsertionPoint(moduleOp.getOperation());
|
||||||
|
kernelModules.push_back(builder.clone(*moduleOp.getOperation()));
|
||||||
}
|
}
|
||||||
OpBuilder builder(funcOp.getOperation());
|
|
||||||
// Create a new spirv::ModuleOp for this function, and clone the
|
|
||||||
// function into it.
|
|
||||||
// TODO : Generalize this to account for different extensions,
|
|
||||||
// capabilities, extended_instruction_sets, other addressing models
|
|
||||||
// and memory models.
|
|
||||||
auto spvModule = builder.create<spirv::ModuleOp>(
|
|
||||||
funcOp.getLoc(),
|
|
||||||
builder.getI32IntegerAttr(
|
|
||||||
static_cast<int32_t>(spirv::AddressingModel::Logical)),
|
|
||||||
builder.getI32IntegerAttr(
|
|
||||||
static_cast<int32_t>(spirv::MemoryModel::GLSL450)),
|
|
||||||
builder.getStrArrayAttr(
|
|
||||||
spirv::stringifyCapability(spirv::Capability::Shader)),
|
|
||||||
builder.getStrArrayAttr(spirv::stringifyExtension(
|
|
||||||
spirv::Extension::SPV_KHR_storage_buffer_storage_class)));
|
|
||||||
// Hardwire the capability to be Shader.
|
|
||||||
OpBuilder moduleBuilder(spvModule.getOperation()->getRegion(0));
|
|
||||||
moduleBuilder.clone(*funcOp.getOperation());
|
|
||||||
spirvModules.push_back(spvModule);
|
|
||||||
});
|
});
|
||||||
|
|
||||||
/// Dialect conversion to lower the functions with the spirv::ModuleOps.
|
|
||||||
SPIRVTypeConverter typeConverter;
|
SPIRVTypeConverter typeConverter;
|
||||||
OwningRewritePatternList patterns;
|
OwningRewritePatternList patterns;
|
||||||
populateGPUToSPIRVPatterns(context, typeConverter, patterns, workGroupSize);
|
populateGPUToSPIRVPatterns(context, typeConverter, patterns, workGroupSize);
|
||||||
@ -105,7 +90,7 @@ void GPUToSPIRVPass::runOnModule() {
|
|||||||
target.addDynamicallyLegalOp<FuncOp>(
|
target.addDynamicallyLegalOp<FuncOp>(
|
||||||
[&](FuncOp op) { return typeConverter.isSignatureLegal(op.getType()); });
|
[&](FuncOp op) { return typeConverter.isSignatureLegal(op.getType()); });
|
||||||
|
|
||||||
if (failed(applyFullConversion(spirvModules, target, patterns,
|
if (failed(applyFullConversion(kernelModules, target, patterns,
|
||||||
&typeConverter))) {
|
&typeConverter))) {
|
||||||
return signalPassFailure();
|
return signalPassFailure();
|
||||||
}
|
}
|
||||||
|
@ -286,7 +286,7 @@ FuncOp mlir::spirv::lowerAsEntryFunction(
|
|||||||
newFuncOp.setType(rewriter.getFunctionType(
|
newFuncOp.setType(rewriter.getFunctionType(
|
||||||
signatureConverter.getConvertedTypes(), llvm::None));
|
signatureConverter.getConvertedTypes(), llvm::None));
|
||||||
rewriter.applySignatureConversion(&newFuncOp.getBody(), signatureConverter);
|
rewriter.applySignatureConversion(&newFuncOp.getBody(), signatureConverter);
|
||||||
rewriter.replaceOp(funcOp.getOperation(), llvm::None);
|
rewriter.eraseOp(funcOp);
|
||||||
|
|
||||||
// Set the attributes for argument and the function.
|
// Set the attributes for argument and the function.
|
||||||
StringRef argABIAttrName = spirv::getInterfaceVarABIAttrName();
|
StringRef argABIAttrName = spirv::getInterfaceVarABIAttrName();
|
||||||
|
49
third_party/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
vendored
49
third_party/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
vendored
@ -75,6 +75,21 @@ static LogicalResult extractValueFromConstOp(Operation *op,
|
|||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename Ty>
|
||||||
|
static ArrayAttr
|
||||||
|
getStrArrayAttrForEnumList(Builder &builder, ArrayRef<Ty> enumValues,
|
||||||
|
llvm::function_ref<StringRef(Ty)> stringifyFn) {
|
||||||
|
if (enumValues.empty()) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
SmallVector<StringRef, 1> enumValStrs;
|
||||||
|
enumValStrs.reserve(enumValues.size());
|
||||||
|
for (auto val : enumValues) {
|
||||||
|
enumValStrs.emplace_back(stringifyFn(val));
|
||||||
|
}
|
||||||
|
return builder.getStrArrayAttr(enumValStrs);
|
||||||
|
}
|
||||||
|
|
||||||
template <typename EnumClass>
|
template <typename EnumClass>
|
||||||
static ParseResult
|
static ParseResult
|
||||||
parseEnumAttribute(EnumClass &value, OpAsmParser &parser,
|
parseEnumAttribute(EnumClass &value, OpAsmParser &parser,
|
||||||
@ -2039,20 +2054,38 @@ void spirv::ModuleOp::build(Builder *builder, OperationState &state) {
|
|||||||
ensureTerminator(*state.addRegion(), *builder, state.location);
|
ensureTerminator(*state.addRegion(), *builder, state.location);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO(ravishankarm): This is only here for resolving some dependency outside
|
||||||
|
// of mlir. Remove once it is done.
|
||||||
void spirv::ModuleOp::build(Builder *builder, OperationState &state,
|
void spirv::ModuleOp::build(Builder *builder, OperationState &state,
|
||||||
IntegerAttr addressing_model,
|
IntegerAttr addressing_model,
|
||||||
IntegerAttr memory_model, ArrayAttr capabilities,
|
IntegerAttr memory_model) {
|
||||||
ArrayAttr extensions,
|
|
||||||
ArrayAttr extended_instruction_sets) {
|
|
||||||
state.addAttribute("addressing_model", addressing_model);
|
state.addAttribute("addressing_model", addressing_model);
|
||||||
state.addAttribute("memory_model", memory_model);
|
state.addAttribute("memory_model", memory_model);
|
||||||
if (capabilities)
|
build(builder, state);
|
||||||
state.addAttribute("capabilities", capabilities);
|
}
|
||||||
if (extensions)
|
|
||||||
state.addAttribute("extensions", extensions);
|
void spirv::ModuleOp::build(Builder *builder, OperationState &state,
|
||||||
|
spirv::AddressingModel addressing_model,
|
||||||
|
spirv::MemoryModel memory_model,
|
||||||
|
ArrayRef<spirv::Capability> capabilities,
|
||||||
|
ArrayRef<spirv::Extension> extensions,
|
||||||
|
ArrayAttr extended_instruction_sets) {
|
||||||
|
state.addAttribute(
|
||||||
|
"addressing_model",
|
||||||
|
builder->getI32IntegerAttr(static_cast<int32_t>(addressing_model)));
|
||||||
|
state.addAttribute("memory_model", builder->getI32IntegerAttr(
|
||||||
|
static_cast<int32_t>(memory_model)));
|
||||||
|
if (!capabilities.empty())
|
||||||
|
state.addAttribute("capabilities",
|
||||||
|
getStrArrayAttrForEnumList<spirv::Capability>(
|
||||||
|
*builder, capabilities, spirv::stringifyCapability));
|
||||||
|
if (!extensions.empty())
|
||||||
|
state.addAttribute("extensions",
|
||||||
|
getStrArrayAttrForEnumList<spirv::Extension>(
|
||||||
|
*builder, extensions, spirv::stringifyExtension));
|
||||||
if (extended_instruction_sets)
|
if (extended_instruction_sets)
|
||||||
state.addAttribute("extended_instruction_sets", extended_instruction_sets);
|
state.addAttribute("extended_instruction_sets", extended_instruction_sets);
|
||||||
ensureTerminator(*state.addRegion(), *builder, state.location);
|
build(builder, state);
|
||||||
}
|
}
|
||||||
|
|
||||||
static ParseResult parseModuleOp(OpAsmParser &parser, OperationState &state) {
|
static ParseResult parseModuleOp(OpAsmParser &parser, OperationState &state) {
|
||||||
|
Loading…
Reference in New Issue
Block a user