Rename GenerateCubinForTfCode to GenerateGpuBinaryForTfCode and add ROCm support
Commit e7161d79b8 (parent 233d56aaec)
BUILD:

```diff
@@ -1,5 +1,6 @@
 load("//tensorflow:tensorflow.bzl", "tf_cc_binary")
 load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
+load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm")

 package(
     default_visibility = [":friends"],
@@ -16,7 +17,7 @@ cc_library(
     name = "gpu_binary_creator",
     srcs = ["gpu_binary_creator.cc"],
     hdrs = ["gpu_binary_creator.h"],
-    copts = if_cuda(["-DGOOGLE_CUDA=1"]),
+    copts = if_cuda(["-DGOOGLE_CUDA=1"]) + if_rocm(["-DTENSORFLOW_USE_ROCM=1"]),
     deps = [
         "@com_google_absl//absl/memory",
         "@com_google_absl//absl/strings",
@@ -29,6 +30,7 @@ cc_library(
         "@llvm-project//mlir:Pass",
         "@llvm-project//mlir:StandardOps",
         "@llvm-project//mlir:TargetNVVMIR",
+        "@llvm-project//mlir:TargetROCDLIR",
         "@llvm-project//mlir:Transforms",
         "//tensorflow/compiler/mlir/tensorflow",
         "//tensorflow/compiler/mlir/hlo",
@@ -44,7 +46,11 @@ cc_library(
         "//tensorflow/compiler/xla/service/mlir_gpu:kernel_lowering",
         "//tensorflow/core:cuda_libdevice_path",
         "//tensorflow/core:lib",
-    ] + if_cuda(["//tensorflow/stream_executor/gpu:asm_compiler"]),
+    ] + if_cuda([
+        "//tensorflow/stream_executor/gpu:asm_compiler",
+    ]) + if_rocm([
+        "//tensorflow/core/platform:rocm_rocdl_path",
+    ]),
 )

 tf_cc_binary(
```
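The `copts` change is what drives the per-backend `#if` blocks in the sources below: building with `--config=cuda` defines `GOOGLE_CUDA=1`, and `--config=rocm` defines `TENSORFLOW_USE_ROCM=1` for this target. A minimal standalone sketch of the resulting preprocessor gating (illustrative, not TensorFlow code):

```cpp
// backend_probe.cc: compile with -DGOOGLE_CUDA=1 or -DTENSORFLOW_USE_ROCM=1
// to pick a branch at preprocessing time; undefined macros evaluate to 0.
#include <iostream>

int main() {
#if GOOGLE_CUDA
  std::cout << "CUDA backend selected (--config=cuda)\n";
#elif TENSORFLOW_USE_ROCM
  std::cout << "ROCm backend selected (--config=rocm)\n";
#else
  std::cout << "no GPU backend configured\n";
#endif
  return 0;
}
```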
gpu_binary_creator.cc:

```diff
@@ -43,6 +43,7 @@ limitations under the License.
 #include "mlir/Pass/Pass.h"  // from @llvm-project
 #include "mlir/Pass/PassManager.h"  // from @llvm-project
 #include "mlir/Target/NVVMIR.h"  // from @llvm-project
+#include "mlir/Target/ROCDLIR.h"  // from @llvm-project
 #include "mlir/Transforms/DialectConversion.h"  // from @llvm-project
 #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
 #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h"
@@ -60,6 +61,8 @@ limitations under the License.
 #include "tensorflow/core/platform/path.h"
 #if GOOGLE_CUDA
 #include "tensorflow/stream_executor/gpu/asm_compiler.h"
+#elif TENSORFLOW_USE_ROCM
+#include "tensorflow/core/platform/rocm_rocdl_path.h"
 #endif

 namespace {
@@ -250,7 +253,8 @@ Status PropagateTensorFlowABIKnowledgeToKernel(

 }  // namespace

-StatusOr<std::vector<uint8_t>> tensorflow::kernel_gen::GenerateCubinForTfCode(
+StatusOr<std::vector<uint8_t>>
+tensorflow::kernel_gen::GenerateGpuBinaryForTfCode(
     llvm::StringRef tf_code, std::pair<int32_t, int32_t> compute_capability,
     llvm::ArrayRef<uint32_t> tile_sizes, llvm::ArrayRef<uint32_t> same_shape,
     llvm::ArrayRef<uint32_t> unroll_factors) {
@@ -267,13 +271,44 @@ StatusOr<std::vector<uint8_t>> tensorflow::kernel_gen::GenerateCubinForTfCode(
     options.use_approximations = true;
     TF_RETURN_IF_ERROR(xla::mlir_gpu::LowerLHLOToGPU(module.get(), options));
   }
+
+#if !defined(TENSORFLOW_USE_ROCM) && !defined(GOOGLE_CUDA)
+  return InternalError(
+      "Neither TENSORFLOW_USE_ROCM nor GOOGLE_CUDA are defined."
+      " Did you specify either --config=rocm or --config=cuda ?");
+#endif
+
+#if TENSORFLOW_USE_ROCM
+  TF_RETURN_IF_ERROR(xla::mlir_gpu::LowerKernelBodiesToROCDL(module.get()));
+#elif GOOGLE_CUDA
   TF_RETURN_IF_ERROR(xla::mlir_gpu::LowerKernelBodiesToNVVM(module.get()));
+#endif
+
   TF_RETURN_IF_ERROR(
       PropagateTensorFlowABIKnowledgeToKernel(module.get(), same_shape));

   mlir::OwningModuleRef kernel_module =
       xla::mlir_gpu::ExtractKernelModule(*module).ValueOrDie();

   llvm::LLVMContext llvmContext;
+
+#if TENSORFLOW_USE_ROCM
+  auto llvmModule = mlir::translateModuleToROCDLIR(*kernel_module, llvmContext);
+  if (!llvmModule) {
+    return InternalError("Could not translate MLIR module to ROCDL IR");
+  }
+
+  llvmModule->setModuleIdentifier("acme");
+
+  xla::HloModuleConfig config;
+  config.set_debug_options(xla::GetDebugOptionsFromFlags());
+
+  int gpu_version = compute_capability.first;
+  std::string libdevice_dir = tensorflow::RocdlRoot();
+
+  return xla::gpu::amdgpu::CompileToHsaco(llvmModule.get(), gpu_version, config,
+                                          libdevice_dir);
+#elif GOOGLE_CUDA
   auto llvmModule = mlir::translateModuleToNVVMIR(*kernel_module, llvmContext);
   if (!llvmModule) {
     return InternalError("Could not translate MLIR module to NVVM");
@@ -296,12 +331,8 @@ StatusOr<std::vector<uint8_t>> tensorflow::kernel_gen::GenerateCubinForTfCode(
                                      config, libdevice_dir, enable_fusion));
   VLOG(1) << ptx;

-#if GOOGLE_CUDA
   return tensorflow::se::CompileGpuAsm(
       std::get<0>(compute_capability), std::get<1>(compute_capability),
       ptx.c_str(), xla::gpu::PtxOptsFromConfig(config));
-#else
-  return InternalError(
-      "GOOGLE_CUDA not defined. Did you specify --config=cuda ?");
-#endif
+#endif
 }
```
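Distilled, `GenerateGpuBinaryForTfCode` now follows one pattern: fail fast when no backend is configured, then lower and compile through exactly one preprocessor branch (ROCDL IR to HSACO on ROCm, or NVVM IR to PTX to cubin on CUDA). A self-contained sketch of that shape; the `*Stub` compilers are invented stand-ins for the real `CompileToHsaco`/`CompileGpuAsm`:

```cpp
#include <cstdint>
#include <stdexcept>
#include <string>
#include <vector>

// Invented stand-ins for the real backend compilers used in the patch above.
std::vector<uint8_t> CompileToHsacoStub(const std::string&) { return {0x7f}; }
std::vector<uint8_t> CompileToCubinStub(const std::string&) { return {0x4e}; }

std::vector<uint8_t> GenerateGpuBinary(const std::string& kernel_ir) {
#if !defined(TENSORFLOW_USE_ROCM) && !defined(GOOGLE_CUDA)
  // Mirrors the InternalError above: neither backend was configured.
  throw std::runtime_error(
      "Neither TENSORFLOW_USE_ROCM nor GOOGLE_CUDA are defined.");
#elif TENSORFLOW_USE_ROCM
  return CompileToHsacoStub(kernel_ir);  // AMD path: LLVM IR -> HSACO blob.
#else
  return CompileToCubinStub(kernel_ir);  // NVIDIA path: PTX -> cubin blob.
#endif
}
```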
gpu_binary_creator.h:

```diff
@@ -31,7 +31,7 @@ limitations under the License.

 namespace tensorflow {
 namespace kernel_gen {
-xla::StatusOr<std::vector<uint8_t>> GenerateCubinForTfCode(
+xla::StatusOr<std::vector<uint8_t>> GenerateGpuBinaryForTfCode(
     llvm::StringRef tf_code,
     std::pair<int32_t, int32_t> compute_capability = {7, 5},
     llvm::ArrayRef<uint32_t> tile_sizes = {16, 64},
```
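The defaults mean a caller may pass only the TF code and get a compute capability 7.5 build with 16x64 tile sizes. A hypothetical, self-contained mock of the same call shape (the real function lives in `tensorflow::kernel_gen` and returns `xla::StatusOr`; everything here is a stand-in):

```cpp
#include <cstdint>
#include <iostream>
#include <string>
#include <utility>
#include <vector>

// Mock with the same defaulted-parameter shape as the declaration above.
// A real implementation would thread compute_capability and tile_sizes
// through the MLIR lowering; this mock just echoes the input bytes.
std::vector<uint8_t> GenerateGpuBinaryForTfCodeMock(
    const std::string& tf_code,
    std::pair<int32_t, int32_t> compute_capability = {7, 5},
    std::vector<uint32_t> tile_sizes = {16, 64}) {
  (void)compute_capability;
  (void)tile_sizes;
  return std::vector<uint8_t>(tf_code.begin(), tf_code.end());
}

int main() {
  // Relying on the defaults: compute capability 7.5, tiles 16x64.
  auto binary = GenerateGpuBinaryForTfCodeMock("module { ... }");
  std::cout << binary.size() << " bytes\n";
}
```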
Kernel generator entry point:

```diff
@@ -53,8 +53,12 @@ int main(int argc, char** argv) {
   tensorflow::InitMlir y(&argc, &argv);
   llvm::cl::ParseCommandLineOptions(argc, argv, "TF op GPU kernel generator\n");

+#if TENSORFLOW_USE_ROCM
+  std::pair<int32_t, int32_t> compute_capability(architecture, 0);
+#else
   std::pair<int32_t, int32_t> compute_capability(architecture / 10,
                                                  architecture % 10);
+#endif

   std::string tf_code;
   auto read_status = tensorflow::ReadFileToString(tensorflow::Env::Default(),
@@ -64,7 +68,7 @@ int main(int argc, char** argv) {
     return 1;
   }

-  auto cubin = tensorflow::kernel_gen::GenerateCubinForTfCode(
+  auto cubin = tensorflow::kernel_gen::GenerateGpuBinaryForTfCode(
       tf_code, compute_capability, tile_sizes, same_shape, unroll_factors);

   if (!cubin.ok()) {
```
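Note that the single architecture flag is now interpreted per backend: on CUDA it packs major*10 + minor (75 means compute capability 7.5), while on ROCm the whole value becomes the GPU version, presumably the AMDGPU ISA number (e.g. 900 for gfx900), with the minor slot unused. A standalone illustration (not TensorFlow code):

```cpp
// Compile with -DTENSORFLOW_USE_ROCM=1 to get the ROCm reading.
#include <cstdint>
#include <iostream>
#include <utility>

std::pair<int32_t, int32_t> InterpretArchitecture(int32_t architecture) {
#if TENSORFLOW_USE_ROCM
  // ROCm: the whole value is the GPU version; the minor slot is unused.
  return {architecture, 0};
#else
  // CUDA: major*10 + minor, e.g. 75 -> compute capability 7.5.
  return {architecture / 10, architecture % 10};
#endif
}

int main() {
  auto cc = InterpretArchitecture(75);
  std::cout << cc.first << "." << cc.second << "\n";  // CUDA build: "7.5"
}
```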
hlo_module_config.cc:

```diff
@@ -51,12 +51,14 @@ string HloModuleConfig::compilation_cache_key() const {
   string key = absl::StrCat("profiling=", hlo_profiling_enabled());
   StrAppend(&key, "::(");
   std::vector<string> params;
-  for (const ShapeLayout& param_layout :
-       entry_computation_layout_->parameter_layouts()) {
-    params.push_back(param_layout.shape().DebugString());
+  if (entry_computation_layout_.has_value()) {
+    for (const ShapeLayout& param_layout :
+         entry_computation_layout_->parameter_layouts()) {
+      params.push_back(param_layout.shape().DebugString());
+    }
+    StrAppend(&key, absl::StrJoin(params, ", "), ") => ",
+              entry_computation_layout_->result_shape().SerializeAsString());
   }
-  StrAppend(&key, absl::StrJoin(params, ", "), ") => ",
-            entry_computation_layout_->result_shape().SerializeAsString());
   if (seed() != 0) {
     // TODO(b/32083678): force recompilation to reset global state.
     static std::atomic<int> counter{0};
```
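This hunk guards a previously unconditional dereference: `entry_computation_layout_` is optional, and `compilation_cache_key()` now folds the entry layout into the key only when it has been set. A simplified, self-contained version of the guarded pattern, with `std::optional` standing in for the real member:

```cpp
#include <iostream>
#include <optional>
#include <string>
#include <vector>

// Simplified stand-in for HloModuleConfig::compilation_cache_key():
// the layout block is appended only when the optional is populated.
std::string CacheKey(
    const std::optional<std::vector<std::string>>& param_layouts) {
  std::string key = "profiling=0::(";
  if (param_layouts.has_value()) {  // previously dereferenced unconditionally
    std::string joined;
    for (const std::string& layout : *param_layouts) {
      if (!joined.empty()) joined += ", ";
      joined += layout;
    }
    key += joined + ") => result_shape";
  }
  return key;
}

int main() {
  std::cout << CacheKey(std::nullopt) << "\n";  // no entry layout set
  std::vector<std::string> layouts = {"f32[2]{0}", "f32[4]{0}"};
  std::cout << CacheKey(layouts) << "\n";       // layouts present
}
```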