Adding a LowerKernelBodiesToROCDL function which lowers MLIR GPU Modules to ROCDL Dialect

This commit is contained in:
Deven Desai 2020-06-29 21:40:30 +00:00
parent bf3ac62fde
commit 233d56aaec
3 changed files with 85 additions and 0 deletions

View File

@ -198,6 +198,7 @@ cc_library(
"@llvm-project//mlir:CFGTransforms",
"@llvm-project//mlir:GPUDialect",
"@llvm-project//mlir:GPUToNVVMTransforms",
"@llvm-project//mlir:GPUToROCDLTransforms",
"@llvm-project//mlir:GPUTransforms",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:LLVMDialect",
@ -206,6 +207,7 @@ cc_library(
"@llvm-project//mlir:LinalgTransforms",
"@llvm-project//mlir:NVVMDialect",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:ROCDLDialect",
"@llvm-project//mlir:SCFDialect",
"@llvm-project//mlir:SCFToGPUPass",
"@llvm-project//mlir:SCFTransforms",

View File

@ -18,6 +18,7 @@ limitations under the License.
#include "absl/memory/memory.h"
#include "mlir/Conversion/AffineToStandard/AffineToStandard.h" // from @llvm-project
#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h" // from @llvm-project
#include "mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h" // from @llvm-project
#include "mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h" // from @llvm-project
#include "mlir/Conversion/SCFToGPU/SCFToGPUPass.h" // from @llvm-project
#include "mlir/Conversion/SCFToStandard/SCFToStandard.h" // from @llvm-project
@ -26,6 +27,7 @@ limitations under the License.
#include "mlir/Dialect/GPU/Passes.h" // from @llvm-project
#include "mlir/Dialect/LLVMIR/LLVMDialect.h" // from @llvm-project
#include "mlir/Dialect/LLVMIR/NVVMDialect.h" // from @llvm-project
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h" // from @llvm-project
#include "mlir/Dialect/Linalg/Passes.h" // from @llvm-project
#include "mlir/Dialect/SCF/Passes.h" // from @llvm-project
#include "mlir/Dialect/SCF/Transforms.h" // from @llvm-project
@ -197,6 +199,85 @@ Status LowerKernelBodiesToNVVM(mlir::ModuleOp module) {
return Status::OK();
}
namespace {
/// A pass that does the final lowering to ROCDL. It collects all the patterns
/// that are currently required, currently mixing std, linalg and gpu.
class LowerToROCDLPass
: public ::mlir::PassWrapper<
LowerToROCDLPass, ::mlir::OperationPass<::mlir::gpu::GPUModuleOp>> {
void getDependentDialects(mlir::DialectRegistry& registry) const override {
registry.insert<mlir::ROCDL::ROCDLDialect, mlir::LLVM::LLVMDialect>();
}
public:
void runOnOperation() override {
::mlir::gpu::GPUModuleOp m = getOperation();
::mlir::OwningRewritePatternList patterns;
::mlir::populateGpuRewritePatterns(m.getContext(), patterns);
::mlir::applyPatternsAndFoldGreedily(m, patterns);
patterns.clear();
::mlir::LLVMTypeConverter converter(m.getContext());
::mlir::populateStdToLLVMConversionPatterns(converter, patterns);
// TODO(b/145824979) Remove linalg once sliceop is in std.
::mlir::populateLinalgToLLVMConversionPatterns(converter, patterns,
&getContext());
::mlir::populateGpuToROCDLConversionPatterns(converter, patterns);
::mlir::populateAffineToStdConversionPatterns(patterns, m.getContext());
::mlir::ConversionTarget target(getContext());
target.addIllegalDialect<::mlir::gpu::GPUDialect>();
target
.addIllegalOp<mlir::LLVM::CosOp, mlir::LLVM::ExpOp, mlir::LLVM::FAbsOp,
mlir::LLVM::FCeilOp, mlir::LLVM::LogOp,
mlir::LLVM::Log10Op, mlir::LLVM::Log2Op>();
target.addIllegalOp<mlir::FuncOp>();
target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
target.addLegalDialect<::mlir::ROCDL::ROCDLDialect>();
// TODO(csigg): Remove once we support replacing non-root ops.
target.addLegalOp<::mlir::gpu::GPUModuleOp, ::mlir::gpu::ModuleEndOp,
::mlir::gpu::YieldOp>();
if (failed(mlir::applyFullConversion(m, target, patterns))) {
signalPassFailure();
}
}
};
} // namespace
Status LowerKernelBodiesToROCDL(mlir::ModuleOp module) {
// We cannot verify as the signature of the kernel is rewritten.
::mlir::PassManager pm(module.getContext(), /*verifyPasses=*/false);
applyPassManagerCLOptions(pm);
auto enable_if_vlog_is_on = [](mlir::Pass* pass, mlir::Operation* op) {
return VLOG_IS_ON(1);
};
pm.enableIRPrinting(/*shouldPrintBeforePass=*/{},
/*shouldPrintAfterPass=*/enable_if_vlog_is_on,
/*printModuleScope=*/false,
/*printAfterOnlyOnChange=*/false,
/*out=*/llvm::dbgs());
// Rewrite kernel functions to LLVM IR.
auto& kernelPm = pm.nest<::mlir::gpu::GPUModuleOp>();
kernelPm.addPass(::mlir::createLowerToCFGPass());
kernelPm.addPass(absl::make_unique<LowerToROCDLPass>());
// Some basic cleanup.
kernelPm.addNestedPass<::mlir::FuncOp>(::mlir::createCanonicalizerPass());
kernelPm.addNestedPass<::mlir::FuncOp>(::mlir::createCSEPass());
// Remove all location information to prevent a debug build.
kernelPm.addPass(::mlir::createStripDebugInfoPass());
if (failed(pm.run(module))) {
return InternalError("Lowering to ROCDL IR failed.");
}
return Status::OK();
}
StatusOr<mlir::ModuleOp> ExtractKernelModule(mlir::ModuleOp module) {
auto kernelModule = ::mlir::ModuleOp::create(module.getLoc());
// TODO(b/137624192): This also needs to resolve naming conflicts.

View File

@ -36,6 +36,8 @@ Status LowerLHLOToGPU(mlir::ModuleOp module,
Status LowerKernelBodiesToNVVM(mlir::ModuleOp module);
Status LowerKernelBodiesToROCDL(mlir::ModuleOp module);
StatusOr<mlir::ModuleOp> ExtractKernelModule(mlir::ModuleOp module);
} // namespace mlir_gpu