Adding a LowerKernelBodiesToROCDL function which lowers MLIR GPU Modules to ROCDL Dialect
This commit is contained in:
parent
bf3ac62fde
commit
233d56aaec
@ -198,6 +198,7 @@ cc_library(
|
||||
"@llvm-project//mlir:CFGTransforms",
|
||||
"@llvm-project//mlir:GPUDialect",
|
||||
"@llvm-project//mlir:GPUToNVVMTransforms",
|
||||
"@llvm-project//mlir:GPUToROCDLTransforms",
|
||||
"@llvm-project//mlir:GPUTransforms",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:LLVMDialect",
|
||||
@ -206,6 +207,7 @@ cc_library(
|
||||
"@llvm-project//mlir:LinalgTransforms",
|
||||
"@llvm-project//mlir:NVVMDialect",
|
||||
"@llvm-project//mlir:Pass",
|
||||
"@llvm-project//mlir:ROCDLDialect",
|
||||
"@llvm-project//mlir:SCFDialect",
|
||||
"@llvm-project//mlir:SCFToGPUPass",
|
||||
"@llvm-project//mlir:SCFTransforms",
|
||||
|
@ -18,6 +18,7 @@ limitations under the License.
|
||||
#include "absl/memory/memory.h"
|
||||
#include "mlir/Conversion/AffineToStandard/AffineToStandard.h" // from @llvm-project
|
||||
#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h" // from @llvm-project
|
||||
#include "mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h" // from @llvm-project
|
||||
#include "mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h" // from @llvm-project
|
||||
#include "mlir/Conversion/SCFToGPU/SCFToGPUPass.h" // from @llvm-project
|
||||
#include "mlir/Conversion/SCFToStandard/SCFToStandard.h" // from @llvm-project
|
||||
@ -26,6 +27,7 @@ limitations under the License.
|
||||
#include "mlir/Dialect/GPU/Passes.h" // from @llvm-project
|
||||
#include "mlir/Dialect/LLVMIR/LLVMDialect.h" // from @llvm-project
|
||||
#include "mlir/Dialect/LLVMIR/NVVMDialect.h" // from @llvm-project
|
||||
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h" // from @llvm-project
|
||||
#include "mlir/Dialect/Linalg/Passes.h" // from @llvm-project
|
||||
#include "mlir/Dialect/SCF/Passes.h" // from @llvm-project
|
||||
#include "mlir/Dialect/SCF/Transforms.h" // from @llvm-project
|
||||
@ -197,6 +199,85 @@ Status LowerKernelBodiesToNVVM(mlir::ModuleOp module) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
/// A pass that does the final lowering to ROCDL. It collects all the patterns
|
||||
/// that are currently required, currently mixing std, linalg and gpu.
|
||||
class LowerToROCDLPass
|
||||
: public ::mlir::PassWrapper<
|
||||
LowerToROCDLPass, ::mlir::OperationPass<::mlir::gpu::GPUModuleOp>> {
|
||||
void getDependentDialects(mlir::DialectRegistry& registry) const override {
|
||||
registry.insert<mlir::ROCDL::ROCDLDialect, mlir::LLVM::LLVMDialect>();
|
||||
}
|
||||
|
||||
public:
|
||||
void runOnOperation() override {
|
||||
::mlir::gpu::GPUModuleOp m = getOperation();
|
||||
|
||||
::mlir::OwningRewritePatternList patterns;
|
||||
::mlir::populateGpuRewritePatterns(m.getContext(), patterns);
|
||||
::mlir::applyPatternsAndFoldGreedily(m, patterns);
|
||||
patterns.clear();
|
||||
|
||||
::mlir::LLVMTypeConverter converter(m.getContext());
|
||||
::mlir::populateStdToLLVMConversionPatterns(converter, patterns);
|
||||
// TODO(b/145824979) Remove linalg once sliceop is in std.
|
||||
::mlir::populateLinalgToLLVMConversionPatterns(converter, patterns,
|
||||
&getContext());
|
||||
::mlir::populateGpuToROCDLConversionPatterns(converter, patterns);
|
||||
::mlir::populateAffineToStdConversionPatterns(patterns, m.getContext());
|
||||
|
||||
::mlir::ConversionTarget target(getContext());
|
||||
target.addIllegalDialect<::mlir::gpu::GPUDialect>();
|
||||
target
|
||||
.addIllegalOp<mlir::LLVM::CosOp, mlir::LLVM::ExpOp, mlir::LLVM::FAbsOp,
|
||||
mlir::LLVM::FCeilOp, mlir::LLVM::LogOp,
|
||||
mlir::LLVM::Log10Op, mlir::LLVM::Log2Op>();
|
||||
target.addIllegalOp<mlir::FuncOp>();
|
||||
target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
|
||||
target.addLegalDialect<::mlir::ROCDL::ROCDLDialect>();
|
||||
// TODO(csigg): Remove once we support replacing non-root ops.
|
||||
target.addLegalOp<::mlir::gpu::GPUModuleOp, ::mlir::gpu::ModuleEndOp,
|
||||
::mlir::gpu::YieldOp>();
|
||||
if (failed(mlir::applyFullConversion(m, target, patterns))) {
|
||||
signalPassFailure();
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
// Lowers the contents of every gpu.module nested in `module` to the LLVM and
// ROCDL dialects, then runs cleanup passes. Returns an internal error if the
// pass pipeline fails.
Status LowerKernelBodiesToROCDL(mlir::ModuleOp module) {
  // Inter-pass verification is turned off: the kernel signature is rewritten
  // during lowering, so we cannot verify.
  ::mlir::PassManager pass_manager(module.getContext(),
                                   /*verifyPasses=*/false);
  applyPassManagerCLOptions(pass_manager);

  // Print the IR after each pass only when VLOG level 1 is active.
  const auto print_when_vlog_enabled = [](mlir::Pass*, mlir::Operation*) {
    return VLOG_IS_ON(1);
  };
  pass_manager.enableIRPrinting(
      /*shouldPrintBeforePass=*/{},
      /*shouldPrintAfterPass=*/print_when_vlog_enabled,
      /*printModuleScope=*/false,
      /*printAfterOnlyOnChange=*/false,
      /*out=*/llvm::dbgs());

  // Rewrite kernel functions to LLVM IR. All passes below run on the
  // gpu.module ops nested inside `module`.
  auto& gpu_module_pm = pass_manager.nest<::mlir::gpu::GPUModuleOp>();
  gpu_module_pm.addPass(::mlir::createLowerToCFGPass());
  gpu_module_pm.addPass(absl::make_unique<LowerToROCDLPass>());

  // Some basic cleanup.
  // NOTE(review): LowerToROCDLPass marks mlir::FuncOp illegal, so presumably
  // no FuncOp is left for these two nested passes to visit -- confirm whether
  // the cleanup is meant to run on the converted LLVM functions instead.
  gpu_module_pm.addNestedPass<::mlir::FuncOp>(
      ::mlir::createCanonicalizerPass());
  gpu_module_pm.addNestedPass<::mlir::FuncOp>(::mlir::createCSEPass());
  // Remove all location information to prevent a debug build.
  gpu_module_pm.addPass(::mlir::createStripDebugInfoPass());

  const bool lowering_failed = ::mlir::failed(pass_manager.run(module));
  if (lowering_failed) {
    return InternalError("Lowering to ROCDL IR failed.");
  }
  return Status::OK();
}
|
||||
|
||||
StatusOr<mlir::ModuleOp> ExtractKernelModule(mlir::ModuleOp module) {
|
||||
auto kernelModule = ::mlir::ModuleOp::create(module.getLoc());
|
||||
// TODO(b/137624192): This also needs to resolve naming conflicts.
|
||||
|
@ -36,6 +36,8 @@ Status LowerLHLOToGPU(mlir::ModuleOp module,
|
||||
|
||||
Status LowerKernelBodiesToNVVM(mlir::ModuleOp module);
|
||||
|
||||
// Lowers the contents of every gpu.module nested in `module` to the LLVM and
// ROCDL dialects. Returns an internal error if the lowering pipeline fails.
Status LowerKernelBodiesToROCDL(mlir::ModuleOp module);
|
||||
|
||||
StatusOr<mlir::ModuleOp> ExtractKernelModule(mlir::ModuleOp module);
|
||||
|
||||
} // namespace mlir_gpu
|
||||
|
Loading…
x
Reference in New Issue
Block a user