From d55c3f96eef62c61a360f709f452ed112b2c9b3d Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 6 Oct 2020 03:13:14 -0700 Subject: [PATCH] [MLIR][KernelGen] Compile for multiple NVIDIA GPU architectures simultaneously For every architecture, compile the kernel module to ptx and to asm. The resulting cubins are then combined into one fatbin using the fatbinary tool. This change only affects the `tf_to_kernel` tool. PiperOrigin-RevId: 335604606 Change-Id: I880871ee1086ba7eace6d13dc13d5d9d352e6042 --- tensorflow/compiler/mlir/runlit.cfg.py | 4 +- .../compiler/mlir/tools/kernel_gen/BUILD | 5 +- .../mlir/tools/kernel_gen/kernel_creator.cc | 12 +- .../mlir/tools/kernel_gen/kernel_creator.h | 5 +- .../tests/tf_to_gpu_binary/tanh.mlir | 3 +- .../tools/kernel_gen/tests/tf_to_kernel/BUILD | 17 +++ .../kernel_gen/tests/tf_to_kernel/tanh.mlir | 6 + .../mlir/tools/kernel_gen/tf_to_gpu_binary.cc | 2 +- .../mlir/tools/kernel_gen/tf_to_kernel.cc | 11 +- .../mlir/tools/kernel_gen/transforms/BUILD | 1 + .../transforms/gpu_kernel_to_blob_pass.cc | 76 +++++++++--- .../mlir/tools/kernel_gen/transforms/passes.h | 3 +- .../tools/kernel_gen/transforms/passes.td | 5 +- .../kernels/mlir_generated/build_defs.bzl | 3 - tensorflow/stream_executor/cuda/BUILD | 4 + tensorflow/stream_executor/gpu/BUILD | 1 + .../stream_executor/gpu/asm_compiler.cc | 112 +++++++++++++++--- tensorflow/stream_executor/gpu/asm_compiler.h | 10 ++ 18 files changed, 227 insertions(+), 53 deletions(-) create mode 100644 tensorflow/compiler/mlir/tools/kernel_gen/tests/tf_to_kernel/BUILD create mode 100644 tensorflow/compiler/mlir/tools/kernel_gen/tests/tf_to_kernel/tanh.mlir diff --git a/tensorflow/compiler/mlir/runlit.cfg.py b/tensorflow/compiler/mlir/runlit.cfg.py index e403a75d3b9..17410b4e5b2 100644 --- a/tensorflow/compiler/mlir/runlit.cfg.py +++ b/tensorflow/compiler/mlir/runlit.cfg.py @@ -74,8 +74,8 @@ tool_names = [ 'tf_tfjs_translate', 'flatbuffer_to_string', 'flatbuffer_translate', 'tf-mlir-translate', 'mlir-tflite-runner', 'tfcompile', 'json_to_flatbuffer', 'xla-gpu-opt', 'xla-mlir-gpu-opt', 'xla-opt', - 'hlo_to_llvm_ir', 'kernel-gen-opt', 'tf_to_gpu_binary', 'xla-thunks-opt', - 'tfjs-opt' + 'hlo_to_llvm_ir', 'kernel-gen-opt', 'tf_to_kernel', 'tf_to_gpu_binary', + 'xla-thunks-opt', 'tfjs-opt' ] tools = [ToolSubst(s, unresolved='ignore') for s in tool_names] llvm_config.add_tool_substitutions(tools, tool_dirs) diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/BUILD b/tensorflow/compiler/mlir/tools/kernel_gen/BUILD index e2646dd712c..ec00554cd54 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/BUILD +++ b/tensorflow/compiler/mlir/tools/kernel_gen/BUILD @@ -105,7 +105,10 @@ tf_cc_binary( tf_cc_binary( name = "tf_to_kernel", srcs = ["tf_to_kernel.cc"], - visibility = ["//tensorflow/core/kernels/mlir_generated:__pkg__"], + visibility = [ + "//tensorflow/compiler/mlir/tools/kernel_gen/tests/tf_to_kernel:__pkg__", + "//tensorflow/core/kernels/mlir_generated:__pkg__", + ], deps = [ ":kernel_creator", "//tensorflow/compiler/mlir:init_mlir", diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/kernel_creator.cc b/tensorflow/compiler/mlir/tools/kernel_gen/kernel_creator.cc index 48696f6e8b0..c3b16721f56 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/kernel_creator.cc +++ b/tensorflow/compiler/mlir/tools/kernel_gen/kernel_creator.cc @@ -174,7 +174,8 @@ Status LowerTFtoGPU(mlir::ModuleOp module, bool gpu_binary_only, Status LowerGPUToLLVM(mlir::ModuleOp module, bool gpu_binary_only, llvm::ArrayRef same_shape, 
llvm::StringRef gpu_binary_attr_name, - int32_t architecture) { + llvm::ArrayRef architectures, + bool generate_fatbin) { mlir::PassManager pm(module.getContext()); applyTensorflowAndCLOptions(pm); @@ -187,7 +188,7 @@ Status LowerGPUToLLVM(mlir::ModuleOp module, bool gpu_binary_only, } kernel_pm.addPass(mlir::createStripDebugInfoPass()); kernel_pm.addPass(mlir::kernel_gen::transforms::CreateGpuKernelToBlobPass( - gpu_binary_attr_name, architecture)); + gpu_binary_attr_name, architectures, generate_fatbin)); if (!gpu_binary_only) { pm.addPass(mlir::kernel_gen::transforms::CreateTFKernelToLLVMPass()); @@ -202,9 +203,9 @@ Status LowerGPUToLLVM(mlir::ModuleOp module, bool gpu_binary_only, StatusOr GenerateKernelForTfCode( mlir::MLIRContext& context, llvm::StringRef tf_code, bool gpu_binary_only, - int32_t architecture, llvm::ArrayRef tile_sizes, + llvm::ArrayRef architectures, llvm::ArrayRef tile_sizes, llvm::ArrayRef same_shape, - llvm::ArrayRef unroll_factors) { + llvm::ArrayRef unroll_factors, bool generate_fatbin) { mlir::RegisterAllTensorFlowDialects(context.getDialectRegistry()); mlir::OwningModuleRef module = mlir::parseSourceString(tf_code, &context); TF_RETURN_IF_ERROR( @@ -221,7 +222,8 @@ StatusOr GenerateKernelForTfCode( TF_RETURN_IF_ERROR(xla::mlir_gpu::LowerKernelBodiesToNVVM(module.get())); #endif TF_RETURN_IF_ERROR(LowerGPUToLLVM(module.get(), gpu_binary_only, same_shape, - kGpuBinaryAttrName, architecture)); + kGpuBinaryAttrName, architectures, + generate_fatbin)); return module; } diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/kernel_creator.h b/tensorflow/compiler/mlir/tools/kernel_gen/kernel_creator.h index b168ec815de..0a74a8a3d5a 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/kernel_creator.h +++ b/tensorflow/compiler/mlir/tools/kernel_gen/kernel_creator.h @@ -38,9 +38,10 @@ namespace kernel_gen { // false, lowers the host side to LLVM Dialect. xla::StatusOr GenerateKernelForTfCode( mlir::MLIRContext& context, llvm::StringRef tf_code, bool gpu_binary_only, - int32_t architecture = 75, llvm::ArrayRef tile_sizes = {16, 64}, + llvm::ArrayRef architectures = {75}, + llvm::ArrayRef tile_sizes = {16, 64}, llvm::ArrayRef same_shape = {}, - llvm::ArrayRef unroll_factors = {}); + llvm::ArrayRef unroll_factors = {}, bool generate_fatbin = true); // Extracts gpu_binary from the converted module. 
xla::StatusOr ExtractGpuBinary(mlir::ModuleOp module); diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf_to_gpu_binary/tanh.mlir b/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf_to_gpu_binary/tanh.mlir index e596c338b14..de9f4aee1cb 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf_to_gpu_binary/tanh.mlir +++ b/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf_to_gpu_binary/tanh.mlir @@ -1,6 +1,5 @@ // RUN: tf_to_gpu_binary --input=%s --output=%t --same_shape=0,1 --unroll_factors=4 --tile_sizes=256 --arch=70 func @tanh(%arg0: tensor) -> tensor { - %0 = "tf.Tanh"(%arg0) { } - : (tensor) -> tensor + %0 = "tf.Tanh"(%arg0) : (tensor) -> tensor return %0 : tensor } diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf_to_kernel/BUILD b/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf_to_kernel/BUILD new file mode 100644 index 00000000000..24e288c246c --- /dev/null +++ b/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf_to_kernel/BUILD @@ -0,0 +1,17 @@ +load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") + +package(licenses = ["notice"]) + +glob_lit_tests( + data = [ + "//tensorflow/compiler/mlir/tools/kernel_gen:tf_to_kernel", + "@llvm-project//mlir:run_lit.sh", + ], + default_tags = [ + # We need access to the CUDA SDK. + "gpu", + "no_rocm", + ], + driver = "//tensorflow/compiler/mlir:run_lit.sh", + test_file_exts = ["mlir"], +) diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf_to_kernel/tanh.mlir b/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf_to_kernel/tanh.mlir new file mode 100644 index 00000000000..d5d1b87bb67 --- /dev/null +++ b/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf_to_kernel/tanh.mlir @@ -0,0 +1,6 @@ +// RUN: tf_to_kernel --input=%s --output=%t --same_shape=0,1 --unroll_factors=4 --tile_sizes=256 --arch=70,75 + +func @tanh(%arg: tensor<*xf32>) -> tensor<*xf32> { + %0 = "tf.Tanh"(%arg) : (tensor<*xf32>) -> tensor<*xf32> + return %0 : tensor<*xf32> +} diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tf_to_gpu_binary.cc b/tensorflow/compiler/mlir/tools/kernel_gen/tf_to_gpu_binary.cc index c7cb92404f5..cbd97e258b7 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/tf_to_gpu_binary.cc +++ b/tensorflow/compiler/mlir/tools/kernel_gen/tf_to_gpu_binary.cc @@ -48,7 +48,7 @@ xla::Status Run(llvm::StringRef input_file, llvm::StringRef output_file, mlir::OwningModuleRef module, GenerateKernelForTfCode(context, tf_code, /*gpu_binary_only=*/true, architecture, tile_sizes, same_shape, - unroll_factors)); + unroll_factors, /*generate_fatbin=*/false)); // Extract gpu_binary. TF_ASSIGN_OR_RETURN(std::string gpu_binary, ExtractGpuBinary(*module)); diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tf_to_kernel.cc b/tensorflow/compiler/mlir/tools/kernel_gen/tf_to_kernel.cc index e62fa47cea9..d2d71a28ff3 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/tf_to_kernel.cc +++ b/tensorflow/compiler/mlir/tools/kernel_gen/tf_to_kernel.cc @@ -95,7 +95,8 @@ xla::StatusOr EmitToBinary(mlir::ModuleOp module) { } xla::Status Run(llvm::StringRef input_file, llvm::StringRef output_file, - int32_t architecture, llvm::ArrayRef tile_sizes, + llvm::ArrayRef architectures, + llvm::ArrayRef tile_sizes, llvm::ArrayRef same_shape, llvm::ArrayRef unroll_factors) { // Read TF code. 
@@ -107,7 +108,7 @@ xla::Status Run(llvm::StringRef input_file, llvm::StringRef output_file, TF_ASSIGN_OR_RETURN( mlir::OwningModuleRef module, GenerateKernelForTfCode(context, tf_code, /*gpu_binary_only=*/false, - architecture, tile_sizes, same_shape, + architectures, tile_sizes, same_shape, unroll_factors)); // Get binary. TF_ASSIGN_OR_RETURN(std::string binary, EmitToBinary(*module)); @@ -129,8 +130,8 @@ int main(int argc, char** argv) { llvm::cl::opt output_file( "output", llvm::cl::desc("output file"), llvm::cl::value_desc("filename"), llvm::cl::init("foo.bin")); - llvm::cl::list architecture( - "arch", llvm::cl::desc("target architecture (e.g. 50 for sm_50)"), + llvm::cl::list architectures( + "arch", llvm::cl::desc("target architectures (e.g. 50 for sm_50)"), llvm::cl::OneOrMore, llvm::cl::CommaSeparated); llvm::cl::list tile_sizes( "tile_sizes", llvm::cl::desc("tile sizes to use"), llvm::cl::ZeroOrMore, @@ -151,7 +152,7 @@ int main(int argc, char** argv) { llvm::cl::ParseCommandLineOptions(argc, argv, "TF op GPU kernel generator\n"); auto status = - tensorflow::kernel_gen::Run(input_file, output_file, architecture.front(), + tensorflow::kernel_gen::Run(input_file, output_file, architectures, tile_sizes, same_shape, unroll_factors); if (!status.ok()) { LOG(ERROR) << status; diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/BUILD b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/BUILD index d4110b466c9..caa665b2971 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/BUILD +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/BUILD @@ -117,6 +117,7 @@ cc_library( "@llvm-project//mlir:AllPassesAndDialects", "@llvm-project//mlir:Support", "@llvm-project//mlir:Transforms", + "@llvm-project//llvm:TransformUtils", "//tensorflow/compiler/mlir/hlo", "//tensorflow/compiler/mlir/hlo:hlo_legalize_to_lhlo", "//tensorflow/compiler/mlir/hlo:lhlo", diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/gpu_kernel_to_blob_pass.cc b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/gpu_kernel_to_blob_pass.cc index dda0e242b2e..83d4f43ade5 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/gpu_kernel_to_blob_pass.cc +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/gpu_kernel_to_blob_pass.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ +#include "llvm/Transforms/Utils/Cloning.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project #include "mlir/Target/NVVMIR.h" // from @llvm-project #include "mlir/Target/ROCDLIR.h" // from @llvm-project @@ -49,9 +50,12 @@ using xla::InternalError; class GpuKernelToBlobPass : public GpuKernelToBlobPassBase { public: - GpuKernelToBlobPass(mlir::StringRef blob_annotation, int32_t arch) { + GpuKernelToBlobPass(mlir::StringRef blob_annotation, + llvm::ArrayRef architectures, + bool generate_fatbin) { blob_annotation_ = blob_annotation.str(); - arch_ = arch; + architectures_ = architectures; + generate_fatbin_ = generate_fatbin; } void runOnOperation() override { @@ -69,7 +73,17 @@ class GpuKernelToBlobPass xla::StatusOr> GetGpuBinaryBlob( mlir::gpu::GPUModuleOp gpu_module) { + if (architectures_.empty()) { + return InternalError("Expected at least one GPU architecture."); + } + if (!generate_fatbin_ && architectures_.size() > 1) { + return InternalError( + "Can only generate machine code for more than one architecture as a " + "fatbin."); + } + llvm::LLVMContext llvmContext; + #if TENSORFLOW_USE_ROCM auto llvmModule = mlir::translateModuleToROCDLIR(gpu_module, llvmContext); if (!llvmModule) { @@ -81,9 +95,14 @@ class GpuKernelToBlobPass xla::HloModuleConfig config; config.set_debug_options(xla::GetDebugOptionsFromFlags()); - std::string libdevice_dir = tensorflow::RocdlRoot(); + // TODO(b/169066682): Support fatbin on ROCm. + if (generate_fatbin_) { + return InternalError("Fatbins are not yet supported for ROCm."); + } - return xla::gpu::amdgpu::CompileToHsaco(llvmModule.get(), arch_, config, + uint32_t arch = architectures_.front(); + std::string libdevice_dir = tensorflow::RocdlRoot(); + return xla::gpu::amdgpu::CompileToHsaco(llvmModule.get(), arch, config, libdevice_dir); #elif GOOGLE_CUDA @@ -102,19 +121,42 @@ class GpuKernelToBlobPass target->Options.AllowFPOpFusion = llvm::FPOpFusion::FPOpFusionMode::Fast; }; - int32_t cc_major = arch_ / 10; - int32_t cc_minor = arch_ % 10; + // Compile and collect requested cubin and PTX images. + std::vector images; TF_ASSIGN_OR_RETURN(std::string libdevice_dir, GetLibdeviceDir(config)); - TF_ASSIGN_OR_RETURN( - std::string ptx, - xla::gpu::nvptx::CompileToPtx(llvmModule.get(), - std::make_pair(cc_major, cc_minor), - config, libdevice_dir, enable_fusion)); - VLOG(1) << ptx; + auto gpu_asm_opts = xla::gpu::PtxOptsFromConfig(config); + for (uint32_t arch : architectures_) { + int32_t cc_major = arch / 10; + int32_t cc_minor = arch % 10; + // Module may be changed by CompileToPtx. + auto llvm_module_copy = llvm::CloneModule(*llvmModule); + TF_ASSIGN_OR_RETURN( + std::string ptx, + xla::gpu::nvptx::CompileToPtx(llvm_module_copy.get(), + std::make_pair(cc_major, cc_minor), + config, libdevice_dir, enable_fusion)); + // TODO(b/169066682): If compute_XX profile, collect PTX image here. + VLOG(1) << ptx; + TF_ASSIGN_OR_RETURN(std::vector gpu_asm, + tensorflow::se::CompileGpuAsm( + cc_major, cc_minor, ptx.c_str(), gpu_asm_opts)); - return tensorflow::se::CompileGpuAsm(cc_major, cc_minor, ptx.c_str(), - xla::gpu::PtxOptsFromConfig(config)); + if (!generate_fatbin_) { + // Skip fatbin generation and return the first and only GPU machine + // code. + return gpu_asm; + } + + // Collect cubin image. + images.push_back({absl::StrCat("sm_", arch), std::move(gpu_asm)}); + } + + // TODO(b/169870789): Revisit the use of fatbins. 
+ // Bundle cubin and PTX images into a single fatbin. + return tensorflow::se::BundleGpuAsm(images, + gpu_asm_opts.preferred_cuda_dir); #endif + return InternalError( "Neither TENSORFLOW_USE_ROCM nor GOOGLE_CUDA are defined." " Did you specify either --config=rocm or --config=cuda ?"); @@ -141,8 +183,10 @@ class GpuKernelToBlobPass } // namespace std::unique_ptr> CreateGpuKernelToBlobPass( - mlir::StringRef blob_annotation, int32_t architecture) { - return std::make_unique(blob_annotation, architecture); + mlir::StringRef blob_annotation, ArrayRef architectures, + bool generate_fatbin) { + return std::make_unique(blob_annotation, architectures, + generate_fatbin); } } // namespace transforms diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h index 2ef863a394c..43e464645a2 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h @@ -61,7 +61,8 @@ CreatePropagateTensorFlowABIKnowledgePass( // Pass to annotate GPU Module with its PTX. std::unique_ptr> CreateGpuKernelToBlobPass( - mlir::StringRef blob_annotation = "", int32_t architecture = 0); + mlir::StringRef blob_annotation = "", ArrayRef architectures = {}, + bool generate_fatbin = true); // Pass to unfuse batch norm. std::unique_ptr CreateUnfuseBatchNormPass(); diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.td b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.td index 5bdd466732b..e84971bbf69 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.td +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.td @@ -53,7 +53,10 @@ def GpuKernelToBlobPass : Pass<"gpu-kernel-to-blob", "gpu::GPUModuleOp"> { let options = [ Option<"blob_annotation_", "blob-annotation", "std::string", /*default=*/"", "Blob attribute name">, - Option<"arch_", "arch", "int32_t", /*default=*/"0", "GPU architecture">, + ListOption<"architectures_", "arch", "uint32_t", "GPU architectures">, + Option<"generate_fatbin_", "generate-fatbin", "bool", /*default=*/"true", + "Bundle machine code for the different architectures in one " + "fatbin.">, ]; let constructor = "transforms::CreateGpuKernelToBlobPass()"; } diff --git a/tensorflow/core/kernels/mlir_generated/build_defs.bzl b/tensorflow/core/kernels/mlir_generated/build_defs.bzl index b92b695e475..1dbecdca90e 100644 --- a/tensorflow/core/kernels/mlir_generated/build_defs.bzl +++ b/tensorflow/core/kernels/mlir_generated/build_defs.bzl @@ -306,9 +306,6 @@ def _gen_unranked_kernel_fatbin_impl(ctx): archs_trimmed.append(arch[3:]) arch_flag = ",".join(archs_trimmed) - # TODO(b/169066682): Generate Fatbin when lowering GPU module. - arch_flag = "75" - filename = "%s.a" % (name) gpu_bin = ctx.outputs.output ctx.actions.run( diff --git a/tensorflow/stream_executor/cuda/BUILD b/tensorflow/stream_executor/cuda/BUILD index ea65d7aee5c..124e7f9d482 100644 --- a/tensorflow/stream_executor/cuda/BUILD +++ b/tensorflow/stream_executor/cuda/BUILD @@ -106,6 +106,10 @@ cc_library( # an intermediate target. cc_library(name = "ptxas_wrapper") +# Buildozer can not remove dependencies inside select guards, so we have to use +# an intermediate target. 
+cc_library(name = "fatbinary_wrapper") + cc_library( name = "cuda_driver", srcs = if_cuda_is_configured(["cuda_driver.cc"]), diff --git a/tensorflow/stream_executor/gpu/BUILD b/tensorflow/stream_executor/gpu/BUILD index a2696bd0088..75a09a4fa97 100644 --- a/tensorflow/stream_executor/gpu/BUILD +++ b/tensorflow/stream_executor/gpu/BUILD @@ -251,6 +251,7 @@ cc_library( ]) + if_cuda_is_configured([ "//tensorflow/stream_executor/cuda:cuda_driver", "//tensorflow/stream_executor/cuda:ptxas_wrapper", + "//tensorflow/stream_executor/cuda:fatbinary_wrapper", ]), ) diff --git a/tensorflow/stream_executor/gpu/asm_compiler.cc b/tensorflow/stream_executor/gpu/asm_compiler.cc index 0f6fd4de910..53f76503f2a 100644 --- a/tensorflow/stream_executor/gpu/asm_compiler.cc +++ b/tensorflow/stream_executor/gpu/asm_compiler.cc @@ -140,34 +140,44 @@ port::StatusOr> CompileGpuAsm(int device_ordinal, return CompileGpuAsm(cc_major, cc_minor, ptx_contents, options); } -port::StatusOr> CompileGpuAsm(int cc_major, int cc_minor, - const char* ptx_contents, - GpuAsmOpts options) { - std::string ptxas_path; - auto env = tensorflow::Env::Default(); - std::string ptxas_binary_name = "ptxas"; +static std::string findCudaExecutable(const std::string binary_name, + const std::string preferred_cuda_dir) { #if defined(PLATFORM_WINDOWS) - ptxas_binary_name += ".exe"; + const std::string binary_filename = binary_name + ".exe"; +#else + const std::string& binary_filename = binary_name; #endif + // Search in cuda root candidates. + auto env = tensorflow::Env::Default(); + std::string binary_path; for (const std::string& cuda_root : - tensorflow::CandidateCudaRoots(options.preferred_cuda_dir)) { - ptxas_path = tensorflow::io::JoinPath(cuda_root, "bin", ptxas_binary_name); - VLOG(2) << "Looking for ptxas at " << ptxas_path; - if (env->FileExists(ptxas_path).ok()) { + tensorflow::CandidateCudaRoots(preferred_cuda_dir)) { + binary_path = tensorflow::io::JoinPath(cuda_root, "bin", binary_filename); + VLOG(2) << "Looking for " << binary_filename << " at " << binary_path; + if (env->FileExists(binary_path).ok()) { break; } } - if (!env->FileExists(ptxas_path).ok()) { + if (!env->FileExists(binary_path).ok()) { // Rely on subprocess invocation to find the correct binary. - ptxas_path = ptxas_binary_name; + binary_path = binary_filename; } - VLOG(2) << "Using ptxas at " << ptxas_path; + VLOG(2) << "Using " << binary_filename << " at " << binary_path; + return binary_path; +} + +port::StatusOr> CompileGpuAsm(int cc_major, int cc_minor, + const char* ptx_contents, + GpuAsmOpts options) { + std::string ptxas_path = + findCudaExecutable("ptxas", options.preferred_cuda_dir); WarnIfBadPtxasVersion(ptxas_path); // Write ptx into a temporary file. std::string ptx_path; + auto env = tensorflow::Env::Default(); if (!env->LocalTempFilename(&ptx_path)) { return port::InternalError("couldn't get temp PTX file name"); } @@ -232,4 +242,78 @@ port::StatusOr> CompileGpuAsm(int cc_major, int cc_minor, return cubin_vector; } +port::StatusOr> BundleGpuAsm( + std::vector images, const std::string preferred_cuda_dir) { + std::string fatbinary_path = + findCudaExecutable("fatbinary", preferred_cuda_dir); + + // Write images to temporary files. 
+  std::vector<std::string> image_paths;
+  auto env = tensorflow::Env::Default();
+  for (const CubinOrPTXImage& img : images) {
+    std::string img_path;
+    if (!env->LocalTempFilename(&img_path)) {
+      return port::InternalError(
+          "Could not get temporary filenames for images.");
+    }
+    TF_RETURN_IF_ERROR(tensorflow::WriteStringToFile(
+        env, img_path, std::string(img.bytes.begin(), img.bytes.end())));
+    VLOG(2) << "image written to " << img_path;
+    image_paths.push_back(std::move(img_path));
+  }
+  auto image_files_cleaner = tensorflow::gtl::MakeCleanup([&image_paths] {
+    for (const auto& path : image_paths) {
+      TF_CHECK_OK(tensorflow::Env::Default()->DeleteFile(path));
+    }
+  });
+
+  // Prepare temporary result file.
+  std::string result_path;
+  if (!env->LocalTempFilename(&result_path)) {
+    return port::InternalError(
+        "Could not get temporary filename for fatbin result.");
+  }
+  auto result_file_cleaner = tensorflow::gtl::MakeCleanup([&result_path] {
+    // This file may never be created, so the failure to delete it should not
+    // propagate to TF.
+    tensorflow::Env::Default()->DeleteFile(result_path).IgnoreError();
+  });
+
+  // Invoke fatbinary and collect its output.
+  tensorflow::SubProcess fatbinary;
+  std::vector<std::string> fatbinary_args = {
+      fatbinary_path, "--64", "--cmdline=--compile-only",
+      "--link", "--compress-all", absl::StrCat("--create=", result_path)};
+  assert(images.size() == image_paths.size());
+  for (int i = 0; i < images.size(); i++) {
+    fatbinary_args.push_back(absl::StrFormat(
+        "--image=profile=%s,file=%s", images[i].profile, image_paths[i]));
+  }
+  if (VLOG_IS_ON(3)) {
+    VLOG(3) << absl::StrJoin(fatbinary_args, " ");
+  }
+  fatbinary.SetProgram(fatbinary_path, fatbinary_args);
+  fatbinary.SetChannelAction(tensorflow::CHAN_STDERR, tensorflow::ACTION_PIPE);
+  if (!fatbinary.Start()) {
+    return port::InternalError("Failed to launch fatbinary.");
+  }
+  std::string stderr_output;
+  int exit_status = fatbinary.Communicate(
+      /*stdin_input=*/nullptr, /*stdout_output=*/nullptr, &stderr_output);
+  if (exit_status != 0) {
+    return port::InternalError(absl::StrFormat(
+        "fatbinary exited with non-zero error code %d, output: %s",
+        exit_status, stderr_output));
+  }
+  if (!stderr_output.empty()) {
+    VLOG(2) << stderr_output;
+  }
+
+  // Read in the result and return it as a byte vector.
+  std::string result_blob;
+  TF_RETURN_IF_ERROR(tensorflow::ReadFileToString(tensorflow::Env::Default(),
+                                                  result_path, &result_blob));
+  return std::vector<uint8>(result_blob.begin(), result_blob.end());
+}
+
 }  // namespace stream_executor
diff --git a/tensorflow/stream_executor/gpu/asm_compiler.h b/tensorflow/stream_executor/gpu/asm_compiler.h
index e5f67a71242..513ac6ca867 100644
--- a/tensorflow/stream_executor/gpu/asm_compiler.h
+++ b/tensorflow/stream_executor/gpu/asm_compiler.h
@@ -52,6 +52,16 @@ port::StatusOr<std::vector<uint8>> CompileGpuAsm(int cc_major, int cc_minor,
 port::StatusOr<std::vector<uint8>> CompileGpuAsmOrGetCached(
     int device_ordinal, const char* ptx, GpuAsmOpts compilation_options);
 
+struct CubinOrPTXImage {
+  std::string profile;
+  std::vector<uint8> bytes;
+};
+
+// Bundles the GPU machine code (cubins) and PTX if requested and returns the
+// resulting binary (i.e. a fatbin) as a byte array.
+port::StatusOr<std::vector<uint8>> BundleGpuAsm(
+    std::vector<CubinOrPTXImage> images, const std::string preferred_cuda_dir);
+
 }  // namespace stream_executor
 
 #endif  // TENSORFLOW_STREAM_EXECUTOR_GPU_ASM_COMPILER_H_
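
Note on the --arch values: the numbers passed to tf_to_kernel (and stored in the pass's architectures_ list option) are plain compute-capability encodings. GpuKernelToBlobPass splits each one into a major/minor pair for PTX compilation and prefixes "sm_" to form the cubin profile handed to fatbinary. The helper below is not part of the patch; it is a small illustrative sketch of that arithmetic with made-up names.

#include <cstdint>
#include <string>
#include <utility>

// Mirrors the arch handling in GpuKernelToBlobPass:
// 70 -> {7, 0} / "sm_70", 75 -> {7, 5} / "sm_75".
std::pair<int32_t, int32_t> ComputeCapability(uint32_t arch) {
  return {static_cast<int32_t>(arch / 10), static_cast<int32_t>(arch % 10)};
}

std::string CubinProfile(uint32_t arch) {
  return "sm_" + std::to_string(arch);
}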
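
The new BundleGpuAsm entry point can also be driven directly. The sketch below is illustrative only, not code from the patch: it assumes two PTX strings (one per target architecture) are already available, a CUDA toolkit providing ptxas and fatbinary can be found under cuda_dir, and error handling is collapsed to ValueOrDie() for brevity. The function and variable names are made up; only CompileGpuAsm, CubinOrPTXImage, BundleGpuAsm, and the preferred_cuda_dir field come from the code above.

#include <string>
#include <utility>
#include <vector>

#include "tensorflow/stream_executor/gpu/asm_compiler.h"

namespace se = stream_executor;

std::vector<uint8_t> BuildFatbin(const std::string& ptx_sm70,
                                 const std::string& ptx_sm75,
                                 const std::string& cuda_dir) {
  se::GpuAsmOpts opts;
  opts.preferred_cuda_dir = cuda_dir;

  // Assemble one cubin per architecture with ptxas.
  auto cubin_70 =
      se::CompileGpuAsm(/*cc_major=*/7, /*cc_minor=*/0, ptx_sm70.c_str(), opts)
          .ValueOrDie();
  auto cubin_75 =
      se::CompileGpuAsm(/*cc_major=*/7, /*cc_minor=*/5, ptx_sm75.c_str(), opts)
          .ValueOrDie();

  // Bundle the per-architecture cubins into one fatbin, mirroring what
  // GpuKernelToBlobPass does when generate-fatbin is enabled.
  std::vector<se::CubinOrPTXImage> images = {{"sm_70", std::move(cubin_70)},
                                             {"sm_75", std::move(cubin_75)}};
  auto fatbin = se::BundleGpuAsm(std::move(images), cuda_dir).ValueOrDie();
  return {fatbin.begin(), fatbin.end()};
}

Internally, BundleGpuAsm writes each image to a temporary file and invokes fatbinary with --64 --cmdline=--compile-only --link --compress-all --create=<result> plus one --image=profile=sm_XX,file=<path> argument per cubin, then reads the resulting fatbin back as a byte vector.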