Move the lowering to final gpu IR passes over to kernel generator and clean them up a bit.
PiperOrigin-RevId: 348615044 Change-Id: I1923f1bbce89544449f4d695eabc53a1d8b3cc0b
This commit is contained in:
parent
36803f4d62
commit
852b977596
@ -56,7 +56,6 @@ cc_library(
|
||||
"//tensorflow/compiler/xla/service/gpu:stream_executor_util",
|
||||
"//tensorflow/compiler/xla/service/gpu:target_constants",
|
||||
"//tensorflow/compiler/xla/service/gpu/llvm_gpu_backend",
|
||||
"//tensorflow/compiler/xla/service/mlir_gpu:kernel_lowering",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core/platform:cuda_libdevice_path",
|
||||
"@llvm-project//llvm:Support",
|
||||
|
@ -56,7 +56,6 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h"
|
||||
#include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h"
|
||||
#include "tensorflow/compiler/mlir/xla/transforms/passes.h"
|
||||
#include "tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.h"
|
||||
#include "tensorflow/compiler/xla/util.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/path.h"
|
||||
@ -228,6 +227,33 @@ Status LowerTFtoGPU(mlir::ModuleOp module, llvm::ArrayRef<uint32_t> tile_sizes,
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status LowerKernelBodiesToLowLevelIr(mlir::ModuleOp module) {
|
||||
#if !defined(TENSORFLOW_USE_ROCM) && !defined(GOOGLE_CUDA)
|
||||
return InternalError(
|
||||
"Neither TENSORFLOW_USE_ROCM nor GOOGLE_CUDA are defined."
|
||||
" Did you specify either --config=rocm or --config=cuda ?");
|
||||
#endif
|
||||
mlir::PassManager pm(module.getContext());
|
||||
// We cannot verify as the signature of the kernel is rewritten.
|
||||
// pm.enableVerifier(false);
|
||||
tensorflow::applyTensorflowAndCLOptions(pm);
|
||||
auto& kernelPm = pm.nest<::mlir::gpu::GPUModuleOp>();
|
||||
kernelPm.addPass(::mlir::createLowerToCFGPass());
|
||||
#if TENSORFLOW_USE_ROCM
|
||||
kernelPm.addPass(mlir::kernel_gen::transforms::CreateGpuKernelToRocdlPass());
|
||||
#elif GOOGLE_CUDA
|
||||
kernelPm.addPass(mlir::kernel_gen::transforms::CreateGpuKernelToNvvmPass());
|
||||
#endif
|
||||
// Remove all location information to prevent a debug build.
|
||||
pm.addPass(::mlir::createStripDebugInfoPass());
|
||||
|
||||
if (failed(pm.run(module))) {
|
||||
return InternalError("Lowering to low-level device IR failed.");
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status AmendKernelLLVMIRWithStaticKnowledge(mlir::ModuleOp module) {
|
||||
mlir::PassManager pm(module.getContext());
|
||||
applyTensorflowAndCLOptions(pm);
|
||||
@ -290,17 +316,7 @@ StatusOr<mlir::OwningModuleRef> GenerateKernelForTfCode(
|
||||
mlir::OwningModuleRef module = mlir::parseSourceString(tf_code, &context);
|
||||
TF_RETURN_IF_ERROR(LowerTFtoGPU(module.get(), tile_sizes, unroll_factors,
|
||||
embed_memref_prints));
|
||||
#if !defined(TENSORFLOW_USE_ROCM) && !defined(GOOGLE_CUDA)
|
||||
return InternalError(
|
||||
"Neither TENSORFLOW_USE_ROCM nor GOOGLE_CUDA are defined."
|
||||
" Did you specify either --config=rocm or --config=cuda ?");
|
||||
#endif
|
||||
|
||||
#if TENSORFLOW_USE_ROCM
|
||||
TF_RETURN_IF_ERROR(xla::mlir_gpu::LowerKernelBodiesToROCDL(module.get()));
|
||||
#elif GOOGLE_CUDA
|
||||
TF_RETURN_IF_ERROR(xla::mlir_gpu::LowerKernelBodiesToNVVM(module.get()));
|
||||
#endif
|
||||
TF_RETURN_IF_ERROR(LowerKernelBodiesToLowLevelIr(module.get()));
|
||||
TF_RETURN_IF_ERROR(AmendKernelLLVMIRWithStaticKnowledge(module.get()));
|
||||
TF_RETURN_IF_ERROR(GenerateDeviceCode(module.get(), kGpuBinaryAttrName,
|
||||
architectures, generate_fatbin,
|
||||
|
@ -78,6 +78,7 @@ cc_library(
|
||||
"embed_tf_framework_pass.cc",
|
||||
"fuse_inner_parallel_loops_pass.cc",
|
||||
"gpu_kernel_to_blob_pass.cc",
|
||||
"kernel_lowering_passes.cc",
|
||||
"map_parallel_loops_to_gpu.cc",
|
||||
"parallel_loops_to_sequential.cc",
|
||||
"same_shape_propagation.cc",
|
||||
@ -88,50 +89,54 @@ cc_library(
|
||||
hdrs = ["passes.h"],
|
||||
copts = if_cuda_is_configured(["-DGOOGLE_CUDA=1"]) + if_rocm_is_configured(["-DTENSORFLOW_USE_ROCM=1"]),
|
||||
deps = [
|
||||
"@llvm-project//mlir:Affine",
|
||||
"@llvm-project//mlir:TensorDialect",
|
||||
"@llvm-project//mlir:TensorTransforms",
|
||||
"//tensorflow/compiler/xla/service:hlo_module_config",
|
||||
"//tensorflow/compiler/xla:debug_options_flags",
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
"//tensorflow/compiler/xla:status",
|
||||
"//tensorflow/compiler/xla/service/gpu:target_constants",
|
||||
"//tensorflow/compiler/xla/service/gpu/llvm_gpu_backend",
|
||||
"//tensorflow/core/platform:cuda_libdevice_path",
|
||||
"//tensorflow/core:lib",
|
||||
":bufferize",
|
||||
":embed_tf_framework",
|
||||
":kernel_gen_passes_inc_gen",
|
||||
":tf_framework_legalize_to_llvm",
|
||||
"@llvm-project//llvm:Support",
|
||||
"@llvm-project//llvm:TransformUtils",
|
||||
"@llvm-project//mlir:Affine",
|
||||
"@llvm-project//mlir:AllPassesAndDialects",
|
||||
"@llvm-project//mlir:Analysis",
|
||||
"@llvm-project//mlir:GPUDialect",
|
||||
"@llvm-project//mlir:GPUTransforms",
|
||||
"@llvm-project//mlir:GPUToGPURuntimeTransforms",
|
||||
"@llvm-project//mlir:GPUToNVVMTransforms",
|
||||
"@llvm-project//mlir:GPUToROCDLTransforms",
|
||||
"@llvm-project//mlir:GPUTransforms",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:LinalgOps",
|
||||
"@llvm-project//mlir:LinalgTransforms",
|
||||
"@llvm-project//mlir:LLVMDialect",
|
||||
"@llvm-project//mlir:LLVMTransforms",
|
||||
"@llvm-project//mlir:LinalgTransforms",
|
||||
"@llvm-project//mlir:LinalgOps",
|
||||
"@llvm-project//mlir:NVVMDialect",
|
||||
"@llvm-project//mlir:Pass",
|
||||
"@llvm-project//mlir:ROCDLDialect",
|
||||
"@llvm-project//mlir:SCFDialect",
|
||||
"@llvm-project//mlir:Shape",
|
||||
"@llvm-project//mlir:Analysis",
|
||||
"@llvm-project//mlir:TargetNVVMIR",
|
||||
"@llvm-project//mlir:TargetROCDLIR",
|
||||
"@llvm-project//mlir:ShapeToStandard",
|
||||
"@llvm-project//mlir:SCFToStandard",
|
||||
"@llvm-project//mlir:Shape",
|
||||
"@llvm-project//mlir:ShapeToStandard",
|
||||
"@llvm-project//mlir:ShapeTransforms",
|
||||
"@llvm-project//mlir:StandardOps",
|
||||
"@llvm-project//mlir:StandardOpsTransforms",
|
||||
"@llvm-project//mlir:AllPassesAndDialects",
|
||||
"@llvm-project//mlir:Support",
|
||||
"@llvm-project//mlir:TargetNVVMIR",
|
||||
"@llvm-project//mlir:TargetROCDLIR",
|
||||
"@llvm-project//mlir:TensorDialect",
|
||||
"@llvm-project//mlir:TensorTransforms",
|
||||
"@llvm-project//mlir:Transforms",
|
||||
"@llvm-project//llvm:TransformUtils",
|
||||
"//tensorflow/compiler/mlir/hlo",
|
||||
"//tensorflow/compiler/mlir/hlo:hlo_legalize_to_lhlo",
|
||||
"//tensorflow/compiler/mlir/hlo:lhlo",
|
||||
"//tensorflow/compiler/xla/service/gpu:stream_executor_util",
|
||||
"//tensorflow/compiler/mlir/tools/kernel_gen/ir:tf_framework_ops",
|
||||
"//tensorflow/compiler/xla:debug_options_flags",
|
||||
"//tensorflow/compiler/xla/service/gpu/llvm_gpu_backend",
|
||||
"//tensorflow/compiler/xla/service/gpu:stream_executor_util",
|
||||
"//tensorflow/compiler/xla/service/gpu:target_constants",
|
||||
"//tensorflow/compiler/xla/service:hlo_module_config",
|
||||
"//tensorflow/compiler/xla:status",
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core/platform:cuda_libdevice_path",
|
||||
] + if_cuda_is_configured([
|
||||
"//tensorflow/stream_executor/gpu:asm_compiler",
|
||||
]) + if_rocm_is_configured([
|
||||
|
@ -0,0 +1,98 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h" // from @llvm-project
|
||||
#include "mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h" // from @llvm-project
|
||||
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" // from @llvm-project
|
||||
#include "mlir/Dialect/GPU/GPUDialect.h" // from @llvm-project
|
||||
#include "mlir/Dialect/LLVMIR/LLVMDialect.h" // from @llvm-project
|
||||
#include "mlir/Dialect/LLVMIR/LLVMTypes.h" // from @llvm-project
|
||||
#include "mlir/Dialect/LLVMIR/NVVMDialect.h" // from @llvm-project
|
||||
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h" // from @llvm-project
|
||||
#include "mlir/Transforms/DialectConversion.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace kernel_gen {
|
||||
namespace transforms {
|
||||
|
||||
using gpu::GPUModuleOp;
|
||||
|
||||
namespace {
|
||||
|
||||
#define GEN_PASS_CLASSES
|
||||
#include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/kernel_gen_passes.h.inc"
|
||||
|
||||
/// A pass that does the final lowering to NVVM. It collects all the patterns
|
||||
/// that are currently required, currently mixing std, linalg and gpu.
|
||||
class GpuKernelToNVVMPass
|
||||
: public GpuKernelToNVVMPassBase<GpuKernelToNVVMPass> {
|
||||
void getDependentDialects(mlir::DialectRegistry& registry) const override {
|
||||
registry.insert<mlir::NVVM::NVVMDialect, mlir::LLVM::LLVMDialect>();
|
||||
}
|
||||
|
||||
public:
|
||||
void runOnOperation() override {
|
||||
GPUModuleOp m = getOperation();
|
||||
|
||||
OwningRewritePatternList patterns;
|
||||
LLVMTypeConverter converter(m.getContext());
|
||||
populateStdToLLVMConversionPatterns(converter, patterns);
|
||||
populateGpuToNVVMConversionPatterns(converter, patterns);
|
||||
ConversionTarget target(getContext());
|
||||
configureGpuToNVVMConversionLegality(target);
|
||||
if (failed(mlir::applyFullConversion(m, target, std::move(patterns)))) {
|
||||
signalPassFailure();
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
/// A pass that does the final lowering to ROCDL. It collects all the patterns
|
||||
/// that are currently required, currently mixing std, linalg and gpu.
|
||||
class GpuKernelToROCDLPass
|
||||
: public GpuKernelToNVVMPassBase<GpuKernelToROCDLPass> {
|
||||
void getDependentDialects(mlir::DialectRegistry& registry) const override {
|
||||
registry.insert<mlir::ROCDL::ROCDLDialect, mlir::LLVM::LLVMDialect>();
|
||||
}
|
||||
|
||||
public:
|
||||
void runOnOperation() override {
|
||||
gpu::GPUModuleOp m = getOperation();
|
||||
|
||||
OwningRewritePatternList patterns;
|
||||
LLVMTypeConverter converter(m.getContext());
|
||||
populateStdToLLVMConversionPatterns(converter, patterns);
|
||||
populateGpuToROCDLConversionPatterns(converter, patterns);
|
||||
ConversionTarget target(getContext());
|
||||
configureGpuToROCDLConversionLegality(target);
|
||||
if (failed(mlir::applyFullConversion(m, target, std::move(patterns)))) {
|
||||
signalPassFailure();
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<OperationPass<GPUModuleOp> > CreateGpuKernelToNvvmPass() {
|
||||
return std::make_unique<GpuKernelToNVVMPass>();
|
||||
}
|
||||
|
||||
std::unique_ptr<OperationPass<GPUModuleOp> > CreateGpuKernelToRocdlPass() {
|
||||
return std::make_unique<GpuKernelToROCDLPass>();
|
||||
}
|
||||
|
||||
} // namespace transforms
|
||||
} // namespace kernel_gen
|
||||
} // namespace mlir
|
@ -88,6 +88,14 @@ std::unique_ptr<mlir::FunctionPass> CreateMapParallelLoopsPass();
|
||||
/// be closed from above.
|
||||
std::unique_ptr<mlir::FunctionPass> CreateFuseInnerParallelLoopsPass();
|
||||
|
||||
/// Pass that transforms gpu modules in standard dialect to NNVM.
|
||||
std::unique_ptr<OperationPass<mlir::gpu::GPUModuleOp>>
|
||||
CreateGpuKernelToNvvmPass();
|
||||
|
||||
/// Pass that transforms gpu modules in standard dialect to ROCDL.
|
||||
std::unique_ptr<OperationPass<mlir::gpu::GPUModuleOp>>
|
||||
CreateGpuKernelToRocdlPass();
|
||||
|
||||
} // namespace transforms
|
||||
|
||||
#define GEN_PASS_REGISTRATION
|
||||
|
@ -61,6 +61,16 @@ def FinalBufferizePass : Pass<"final-bufferize", "ModuleOp"> {
|
||||
let constructor = "transforms::CreateFinalBufferizePass()";
|
||||
}
|
||||
|
||||
def GpuKernelToNVVMPass : Pass<"gpu-kernel-to-nvvm", "gpu::GPUModuleOp"> {
|
||||
let summary = "Pass to transform a gpu module to nvvm.";
|
||||
let constructor = "transforms::CreateGpuKernelToNvvmPass()";
|
||||
}
|
||||
|
||||
def GpuKernelToROCDLPass : Pass<"gpu-kernel-to-rocdl", "gpu::GPUModuleOp"> {
|
||||
let summary = "Pass to transform a gpu module to rocdl.";
|
||||
let constructor = "transforms::CreateGpuKernelToRocdlPass()";
|
||||
}
|
||||
|
||||
def GpuKernelToBlobPass : Pass<"gpu-kernel-to-blob", "gpu::GPUModuleOp"> {
|
||||
let summary = "Pass to annotate GPU Module with its PTX";
|
||||
let options = [
|
||||
|
Loading…
Reference in New Issue
Block a user