[MLIR][KernelGen] Compile for multiple NVIDIA GPU architectures simultaneously
For every architecture, compile the kernel module to ptx and to asm. The resulting cubins are then combined into one fatbin using the fatbinary tool. This change only affects the `tf_to_kernel` tool. PiperOrigin-RevId: 335604606 Change-Id: I880871ee1086ba7eace6d13dc13d5d9d352e6042
This commit is contained in:
parent
e0bd971a0a
commit
d55c3f96ee
@ -74,8 +74,8 @@ tool_names = [
|
||||
'tf_tfjs_translate', 'flatbuffer_to_string', 'flatbuffer_translate',
|
||||
'tf-mlir-translate', 'mlir-tflite-runner', 'tfcompile',
|
||||
'json_to_flatbuffer', 'xla-gpu-opt', 'xla-mlir-gpu-opt', 'xla-opt',
|
||||
'hlo_to_llvm_ir', 'kernel-gen-opt', 'tf_to_gpu_binary', 'xla-thunks-opt',
|
||||
'tfjs-opt'
|
||||
'hlo_to_llvm_ir', 'kernel-gen-opt', 'tf_to_kernel', 'tf_to_gpu_binary',
|
||||
'xla-thunks-opt', 'tfjs-opt'
|
||||
]
|
||||
tools = [ToolSubst(s, unresolved='ignore') for s in tool_names]
|
||||
llvm_config.add_tool_substitutions(tools, tool_dirs)
|
||||
|
||||
@ -105,7 +105,10 @@ tf_cc_binary(
|
||||
tf_cc_binary(
|
||||
name = "tf_to_kernel",
|
||||
srcs = ["tf_to_kernel.cc"],
|
||||
visibility = ["//tensorflow/core/kernels/mlir_generated:__pkg__"],
|
||||
visibility = [
|
||||
"//tensorflow/compiler/mlir/tools/kernel_gen/tests/tf_to_kernel:__pkg__",
|
||||
"//tensorflow/core/kernels/mlir_generated:__pkg__",
|
||||
],
|
||||
deps = [
|
||||
":kernel_creator",
|
||||
"//tensorflow/compiler/mlir:init_mlir",
|
||||
|
||||
@ -174,7 +174,8 @@ Status LowerTFtoGPU(mlir::ModuleOp module, bool gpu_binary_only,
|
||||
Status LowerGPUToLLVM(mlir::ModuleOp module, bool gpu_binary_only,
|
||||
llvm::ArrayRef<uint32_t> same_shape,
|
||||
llvm::StringRef gpu_binary_attr_name,
|
||||
int32_t architecture) {
|
||||
llvm::ArrayRef<uint32_t> architectures,
|
||||
bool generate_fatbin) {
|
||||
mlir::PassManager pm(module.getContext());
|
||||
applyTensorflowAndCLOptions(pm);
|
||||
|
||||
@ -187,7 +188,7 @@ Status LowerGPUToLLVM(mlir::ModuleOp module, bool gpu_binary_only,
|
||||
}
|
||||
kernel_pm.addPass(mlir::createStripDebugInfoPass());
|
||||
kernel_pm.addPass(mlir::kernel_gen::transforms::CreateGpuKernelToBlobPass(
|
||||
gpu_binary_attr_name, architecture));
|
||||
gpu_binary_attr_name, architectures, generate_fatbin));
|
||||
|
||||
if (!gpu_binary_only) {
|
||||
pm.addPass(mlir::kernel_gen::transforms::CreateTFKernelToLLVMPass());
|
||||
@ -202,9 +203,9 @@ Status LowerGPUToLLVM(mlir::ModuleOp module, bool gpu_binary_only,
|
||||
|
||||
StatusOr<mlir::OwningModuleRef> GenerateKernelForTfCode(
|
||||
mlir::MLIRContext& context, llvm::StringRef tf_code, bool gpu_binary_only,
|
||||
int32_t architecture, llvm::ArrayRef<uint32_t> tile_sizes,
|
||||
llvm::ArrayRef<uint32_t> architectures, llvm::ArrayRef<uint32_t> tile_sizes,
|
||||
llvm::ArrayRef<uint32_t> same_shape,
|
||||
llvm::ArrayRef<uint32_t> unroll_factors) {
|
||||
llvm::ArrayRef<uint32_t> unroll_factors, bool generate_fatbin) {
|
||||
mlir::RegisterAllTensorFlowDialects(context.getDialectRegistry());
|
||||
mlir::OwningModuleRef module = mlir::parseSourceString(tf_code, &context);
|
||||
TF_RETURN_IF_ERROR(
|
||||
@ -221,7 +222,8 @@ StatusOr<mlir::OwningModuleRef> GenerateKernelForTfCode(
|
||||
TF_RETURN_IF_ERROR(xla::mlir_gpu::LowerKernelBodiesToNVVM(module.get()));
|
||||
#endif
|
||||
TF_RETURN_IF_ERROR(LowerGPUToLLVM(module.get(), gpu_binary_only, same_shape,
|
||||
kGpuBinaryAttrName, architecture));
|
||||
kGpuBinaryAttrName, architectures,
|
||||
generate_fatbin));
|
||||
return module;
|
||||
}
|
||||
|
||||
|
||||
@ -38,9 +38,10 @@ namespace kernel_gen {
|
||||
// false, lowers the host side to LLVM Dialect.
|
||||
xla::StatusOr<mlir::OwningModuleRef> GenerateKernelForTfCode(
|
||||
mlir::MLIRContext& context, llvm::StringRef tf_code, bool gpu_binary_only,
|
||||
int32_t architecture = 75, llvm::ArrayRef<uint32_t> tile_sizes = {16, 64},
|
||||
llvm::ArrayRef<uint32_t> architectures = {75},
|
||||
llvm::ArrayRef<uint32_t> tile_sizes = {16, 64},
|
||||
llvm::ArrayRef<uint32_t> same_shape = {},
|
||||
llvm::ArrayRef<uint32_t> unroll_factors = {});
|
||||
llvm::ArrayRef<uint32_t> unroll_factors = {}, bool generate_fatbin = true);
|
||||
|
||||
// Extracts gpu_binary from the converted module.
|
||||
xla::StatusOr<std::string> ExtractGpuBinary(mlir::ModuleOp module);
|
||||
|
||||
@ -1,6 +1,5 @@
|
||||
// RUN: tf_to_gpu_binary --input=%s --output=%t --same_shape=0,1 --unroll_factors=4 --tile_sizes=256 --arch=70
|
||||
func @tanh(%arg0: tensor<?xf32>) -> tensor<?xf32> {
|
||||
%0 = "tf.Tanh"(%arg0) { }
|
||||
: (tensor<?xf32>) -> tensor<?xf32>
|
||||
%0 = "tf.Tanh"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
|
||||
return %0 : tensor<?xf32>
|
||||
}
|
||||
|
||||
@ -0,0 +1,17 @@
|
||||
load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests")
|
||||
|
||||
package(licenses = ["notice"])
|
||||
|
||||
glob_lit_tests(
|
||||
data = [
|
||||
"//tensorflow/compiler/mlir/tools/kernel_gen:tf_to_kernel",
|
||||
"@llvm-project//mlir:run_lit.sh",
|
||||
],
|
||||
default_tags = [
|
||||
# We need access to the CUDA SDK.
|
||||
"gpu",
|
||||
"no_rocm",
|
||||
],
|
||||
driver = "//tensorflow/compiler/mlir:run_lit.sh",
|
||||
test_file_exts = ["mlir"],
|
||||
)
|
||||
@ -0,0 +1,6 @@
|
||||
// RUN: tf_to_kernel --input=%s --output=%t --same_shape=0,1 --unroll_factors=4 --tile_sizes=256 --arch=70,75
|
||||
|
||||
func @tanh(%arg: tensor<*xf32>) -> tensor<*xf32> {
|
||||
%0 = "tf.Tanh"(%arg) : (tensor<*xf32>) -> tensor<*xf32>
|
||||
return %0 : tensor<*xf32>
|
||||
}
|
||||
@ -48,7 +48,7 @@ xla::Status Run(llvm::StringRef input_file, llvm::StringRef output_file,
|
||||
mlir::OwningModuleRef module,
|
||||
GenerateKernelForTfCode(context, tf_code, /*gpu_binary_only=*/true,
|
||||
architecture, tile_sizes, same_shape,
|
||||
unroll_factors));
|
||||
unroll_factors, /*generate_fatbin=*/false));
|
||||
// Extract gpu_binary.
|
||||
TF_ASSIGN_OR_RETURN(std::string gpu_binary, ExtractGpuBinary(*module));
|
||||
|
||||
|
||||
@ -95,7 +95,8 @@ xla::StatusOr<std::string> EmitToBinary(mlir::ModuleOp module) {
|
||||
}
|
||||
|
||||
xla::Status Run(llvm::StringRef input_file, llvm::StringRef output_file,
|
||||
int32_t architecture, llvm::ArrayRef<uint32_t> tile_sizes,
|
||||
llvm::ArrayRef<uint32_t> architectures,
|
||||
llvm::ArrayRef<uint32_t> tile_sizes,
|
||||
llvm::ArrayRef<uint32_t> same_shape,
|
||||
llvm::ArrayRef<uint32_t> unroll_factors) {
|
||||
// Read TF code.
|
||||
@ -107,7 +108,7 @@ xla::Status Run(llvm::StringRef input_file, llvm::StringRef output_file,
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
mlir::OwningModuleRef module,
|
||||
GenerateKernelForTfCode(context, tf_code, /*gpu_binary_only=*/false,
|
||||
architecture, tile_sizes, same_shape,
|
||||
architectures, tile_sizes, same_shape,
|
||||
unroll_factors));
|
||||
// Get binary.
|
||||
TF_ASSIGN_OR_RETURN(std::string binary, EmitToBinary(*module));
|
||||
@ -129,8 +130,8 @@ int main(int argc, char** argv) {
|
||||
llvm::cl::opt<std::string> output_file(
|
||||
"output", llvm::cl::desc("output file"), llvm::cl::value_desc("filename"),
|
||||
llvm::cl::init("foo.bin"));
|
||||
llvm::cl::list<int32_t> architecture(
|
||||
"arch", llvm::cl::desc("target architecture (e.g. 50 for sm_50)"),
|
||||
llvm::cl::list<uint32_t> architectures(
|
||||
"arch", llvm::cl::desc("target architectures (e.g. 50 for sm_50)"),
|
||||
llvm::cl::OneOrMore, llvm::cl::CommaSeparated);
|
||||
llvm::cl::list<uint32_t> tile_sizes(
|
||||
"tile_sizes", llvm::cl::desc("tile sizes to use"), llvm::cl::ZeroOrMore,
|
||||
@ -151,7 +152,7 @@ int main(int argc, char** argv) {
|
||||
llvm::cl::ParseCommandLineOptions(argc, argv, "TF op GPU kernel generator\n");
|
||||
|
||||
auto status =
|
||||
tensorflow::kernel_gen::Run(input_file, output_file, architecture.front(),
|
||||
tensorflow::kernel_gen::Run(input_file, output_file, architectures,
|
||||
tile_sizes, same_shape, unroll_factors);
|
||||
if (!status.ok()) {
|
||||
LOG(ERROR) << status;
|
||||
|
||||
@ -117,6 +117,7 @@ cc_library(
|
||||
"@llvm-project//mlir:AllPassesAndDialects",
|
||||
"@llvm-project//mlir:Support",
|
||||
"@llvm-project//mlir:Transforms",
|
||||
"@llvm-project//llvm:TransformUtils",
|
||||
"//tensorflow/compiler/mlir/hlo",
|
||||
"//tensorflow/compiler/mlir/hlo:hlo_legalize_to_lhlo",
|
||||
"//tensorflow/compiler/mlir/hlo:lhlo",
|
||||
|
||||
@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "llvm/Transforms/Utils/Cloning.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
|
||||
#include "mlir/Target/NVVMIR.h" // from @llvm-project
|
||||
#include "mlir/Target/ROCDLIR.h" // from @llvm-project
|
||||
@ -49,9 +50,12 @@ using xla::InternalError;
|
||||
class GpuKernelToBlobPass
|
||||
: public GpuKernelToBlobPassBase<GpuKernelToBlobPass> {
|
||||
public:
|
||||
GpuKernelToBlobPass(mlir::StringRef blob_annotation, int32_t arch) {
|
||||
GpuKernelToBlobPass(mlir::StringRef blob_annotation,
|
||||
llvm::ArrayRef<uint32_t> architectures,
|
||||
bool generate_fatbin) {
|
||||
blob_annotation_ = blob_annotation.str();
|
||||
arch_ = arch;
|
||||
architectures_ = architectures;
|
||||
generate_fatbin_ = generate_fatbin;
|
||||
}
|
||||
|
||||
void runOnOperation() override {
|
||||
@ -69,7 +73,17 @@ class GpuKernelToBlobPass
|
||||
|
||||
xla::StatusOr<std::vector<uint8_t>> GetGpuBinaryBlob(
|
||||
mlir::gpu::GPUModuleOp gpu_module) {
|
||||
if (architectures_.empty()) {
|
||||
return InternalError("Expected at least one GPU architecture.");
|
||||
}
|
||||
if (!generate_fatbin_ && architectures_.size() > 1) {
|
||||
return InternalError(
|
||||
"Can only generate machine code for more than one architecture as a "
|
||||
"fatbin.");
|
||||
}
|
||||
|
||||
llvm::LLVMContext llvmContext;
|
||||
|
||||
#if TENSORFLOW_USE_ROCM
|
||||
auto llvmModule = mlir::translateModuleToROCDLIR(gpu_module, llvmContext);
|
||||
if (!llvmModule) {
|
||||
@ -81,9 +95,14 @@ class GpuKernelToBlobPass
|
||||
xla::HloModuleConfig config;
|
||||
config.set_debug_options(xla::GetDebugOptionsFromFlags());
|
||||
|
||||
std::string libdevice_dir = tensorflow::RocdlRoot();
|
||||
// TODO(b/169066682): Support fatbin on ROCm.
|
||||
if (generate_fatbin_) {
|
||||
return InternalError("Fatbins are not yet supported for ROCm.");
|
||||
}
|
||||
|
||||
return xla::gpu::amdgpu::CompileToHsaco(llvmModule.get(), arch_, config,
|
||||
uint32_t arch = architectures_.front();
|
||||
std::string libdevice_dir = tensorflow::RocdlRoot();
|
||||
return xla::gpu::amdgpu::CompileToHsaco(llvmModule.get(), arch, config,
|
||||
libdevice_dir);
|
||||
|
||||
#elif GOOGLE_CUDA
|
||||
@ -102,19 +121,42 @@ class GpuKernelToBlobPass
|
||||
target->Options.AllowFPOpFusion = llvm::FPOpFusion::FPOpFusionMode::Fast;
|
||||
};
|
||||
|
||||
int32_t cc_major = arch_ / 10;
|
||||
int32_t cc_minor = arch_ % 10;
|
||||
// Compile and collect requested cubin and PTX images.
|
||||
std::vector<tensorflow::se::CubinOrPTXImage> images;
|
||||
TF_ASSIGN_OR_RETURN(std::string libdevice_dir, GetLibdeviceDir(config));
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
std::string ptx,
|
||||
xla::gpu::nvptx::CompileToPtx(llvmModule.get(),
|
||||
std::make_pair(cc_major, cc_minor),
|
||||
config, libdevice_dir, enable_fusion));
|
||||
VLOG(1) << ptx;
|
||||
auto gpu_asm_opts = xla::gpu::PtxOptsFromConfig(config);
|
||||
for (uint32_t arch : architectures_) {
|
||||
int32_t cc_major = arch / 10;
|
||||
int32_t cc_minor = arch % 10;
|
||||
// Module may be changed by CompileToPtx.
|
||||
auto llvm_module_copy = llvm::CloneModule(*llvmModule);
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
std::string ptx,
|
||||
xla::gpu::nvptx::CompileToPtx(llvm_module_copy.get(),
|
||||
std::make_pair(cc_major, cc_minor),
|
||||
config, libdevice_dir, enable_fusion));
|
||||
// TODO(b/169066682): If compute_XX profile, collect PTX image here.
|
||||
VLOG(1) << ptx;
|
||||
TF_ASSIGN_OR_RETURN(std::vector<uint8_t> gpu_asm,
|
||||
tensorflow::se::CompileGpuAsm(
|
||||
cc_major, cc_minor, ptx.c_str(), gpu_asm_opts));
|
||||
|
||||
return tensorflow::se::CompileGpuAsm(cc_major, cc_minor, ptx.c_str(),
|
||||
xla::gpu::PtxOptsFromConfig(config));
|
||||
if (!generate_fatbin_) {
|
||||
// Skip fatbin generation and return the first and only GPU machine
|
||||
// code.
|
||||
return gpu_asm;
|
||||
}
|
||||
|
||||
// Collect cubin image.
|
||||
images.push_back({absl::StrCat("sm_", arch), std::move(gpu_asm)});
|
||||
}
|
||||
|
||||
// TODO(b/169870789): Revisit the use of fatbins.
|
||||
// Bundle cubin and PTX images into a single fatbin.
|
||||
return tensorflow::se::BundleGpuAsm(images,
|
||||
gpu_asm_opts.preferred_cuda_dir);
|
||||
#endif
|
||||
|
||||
return InternalError(
|
||||
"Neither TENSORFLOW_USE_ROCM nor GOOGLE_CUDA are defined."
|
||||
" Did you specify either --config=rocm or --config=cuda ?");
|
||||
@ -141,8 +183,10 @@ class GpuKernelToBlobPass
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<OperationPass<gpu::GPUModuleOp>> CreateGpuKernelToBlobPass(
|
||||
mlir::StringRef blob_annotation, int32_t architecture) {
|
||||
return std::make_unique<GpuKernelToBlobPass>(blob_annotation, architecture);
|
||||
mlir::StringRef blob_annotation, ArrayRef<uint32_t> architectures,
|
||||
bool generate_fatbin) {
|
||||
return std::make_unique<GpuKernelToBlobPass>(blob_annotation, architectures,
|
||||
generate_fatbin);
|
||||
}
|
||||
|
||||
} // namespace transforms
|
||||
|
||||
@ -61,7 +61,8 @@ CreatePropagateTensorFlowABIKnowledgePass(
|
||||
|
||||
// Pass to annotate GPU Module with its PTX.
|
||||
std::unique_ptr<OperationPass<gpu::GPUModuleOp>> CreateGpuKernelToBlobPass(
|
||||
mlir::StringRef blob_annotation = "", int32_t architecture = 0);
|
||||
mlir::StringRef blob_annotation = "", ArrayRef<uint32_t> architectures = {},
|
||||
bool generate_fatbin = true);
|
||||
|
||||
// Pass to unfuse batch norm.
|
||||
std::unique_ptr<FunctionPass> CreateUnfuseBatchNormPass();
|
||||
|
||||
@ -53,7 +53,10 @@ def GpuKernelToBlobPass : Pass<"gpu-kernel-to-blob", "gpu::GPUModuleOp"> {
|
||||
let options = [
|
||||
Option<"blob_annotation_", "blob-annotation", "std::string",
|
||||
/*default=*/"", "Blob attribute name">,
|
||||
Option<"arch_", "arch", "int32_t", /*default=*/"0", "GPU architecture">,
|
||||
ListOption<"architectures_", "arch", "uint32_t", "GPU architectures">,
|
||||
Option<"generate_fatbin_", "generate-fatbin", "bool", /*default=*/"true",
|
||||
"Bundle machine code for the different architectures in one "
|
||||
"fatbin.">,
|
||||
];
|
||||
let constructor = "transforms::CreateGpuKernelToBlobPass()";
|
||||
}
|
||||
|
||||
@ -306,9 +306,6 @@ def _gen_unranked_kernel_fatbin_impl(ctx):
|
||||
archs_trimmed.append(arch[3:])
|
||||
arch_flag = ",".join(archs_trimmed)
|
||||
|
||||
# TODO(b/169066682): Generate Fatbin when lowering GPU module.
|
||||
arch_flag = "75"
|
||||
|
||||
filename = "%s.a" % (name)
|
||||
gpu_bin = ctx.outputs.output
|
||||
ctx.actions.run(
|
||||
|
||||
@ -106,6 +106,10 @@ cc_library(
|
||||
# an intermediate target.
|
||||
cc_library(name = "ptxas_wrapper")
|
||||
|
||||
# Buildozer can not remove dependencies inside select guards, so we have to use
|
||||
# an intermediate target.
|
||||
cc_library(name = "fatbinary_wrapper")
|
||||
|
||||
cc_library(
|
||||
name = "cuda_driver",
|
||||
srcs = if_cuda_is_configured(["cuda_driver.cc"]),
|
||||
|
||||
@ -251,6 +251,7 @@ cc_library(
|
||||
]) + if_cuda_is_configured([
|
||||
"//tensorflow/stream_executor/cuda:cuda_driver",
|
||||
"//tensorflow/stream_executor/cuda:ptxas_wrapper",
|
||||
"//tensorflow/stream_executor/cuda:fatbinary_wrapper",
|
||||
]),
|
||||
)
|
||||
|
||||
|
||||
@ -140,34 +140,44 @@ port::StatusOr<std::vector<uint8>> CompileGpuAsm(int device_ordinal,
|
||||
return CompileGpuAsm(cc_major, cc_minor, ptx_contents, options);
|
||||
}
|
||||
|
||||
port::StatusOr<std::vector<uint8>> CompileGpuAsm(int cc_major, int cc_minor,
|
||||
const char* ptx_contents,
|
||||
GpuAsmOpts options) {
|
||||
std::string ptxas_path;
|
||||
auto env = tensorflow::Env::Default();
|
||||
std::string ptxas_binary_name = "ptxas";
|
||||
static std::string findCudaExecutable(const std::string binary_name,
|
||||
const std::string preferred_cuda_dir) {
|
||||
#if defined(PLATFORM_WINDOWS)
|
||||
ptxas_binary_name += ".exe";
|
||||
const std::string binary_filename = binary_name + ".exe";
|
||||
#else
|
||||
const std::string& binary_filename = binary_name;
|
||||
#endif
|
||||
|
||||
// Search in cuda root candidates.
|
||||
auto env = tensorflow::Env::Default();
|
||||
std::string binary_path;
|
||||
for (const std::string& cuda_root :
|
||||
tensorflow::CandidateCudaRoots(options.preferred_cuda_dir)) {
|
||||
ptxas_path = tensorflow::io::JoinPath(cuda_root, "bin", ptxas_binary_name);
|
||||
VLOG(2) << "Looking for ptxas at " << ptxas_path;
|
||||
if (env->FileExists(ptxas_path).ok()) {
|
||||
tensorflow::CandidateCudaRoots(preferred_cuda_dir)) {
|
||||
binary_path = tensorflow::io::JoinPath(cuda_root, "bin", binary_filename);
|
||||
VLOG(2) << "Looking for " << binary_filename << " at " << binary_path;
|
||||
if (env->FileExists(binary_path).ok()) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (!env->FileExists(ptxas_path).ok()) {
|
||||
if (!env->FileExists(binary_path).ok()) {
|
||||
// Rely on subprocess invocation to find the correct binary.
|
||||
ptxas_path = ptxas_binary_name;
|
||||
binary_path = binary_filename;
|
||||
}
|
||||
VLOG(2) << "Using ptxas at " << ptxas_path;
|
||||
VLOG(2) << "Using " << binary_filename << " at " << binary_path;
|
||||
return binary_path;
|
||||
}
|
||||
|
||||
port::StatusOr<std::vector<uint8>> CompileGpuAsm(int cc_major, int cc_minor,
|
||||
const char* ptx_contents,
|
||||
GpuAsmOpts options) {
|
||||
std::string ptxas_path =
|
||||
findCudaExecutable("ptxas", options.preferred_cuda_dir);
|
||||
|
||||
WarnIfBadPtxasVersion(ptxas_path);
|
||||
|
||||
// Write ptx into a temporary file.
|
||||
std::string ptx_path;
|
||||
auto env = tensorflow::Env::Default();
|
||||
if (!env->LocalTempFilename(&ptx_path)) {
|
||||
return port::InternalError("couldn't get temp PTX file name");
|
||||
}
|
||||
@ -232,4 +242,78 @@ port::StatusOr<std::vector<uint8>> CompileGpuAsm(int cc_major, int cc_minor,
|
||||
return cubin_vector;
|
||||
}
|
||||
|
||||
port::StatusOr<std::vector<uint8>> BundleGpuAsm(
|
||||
std::vector<CubinOrPTXImage> images, const std::string preferred_cuda_dir) {
|
||||
std::string fatbinary_path =
|
||||
findCudaExecutable("fatbinary", preferred_cuda_dir);
|
||||
|
||||
// Write images to temporary files.
|
||||
std::vector<std::string> image_paths;
|
||||
auto env = tensorflow::Env::Default();
|
||||
for (const CubinOrPTXImage& img : images) {
|
||||
std::string img_path;
|
||||
if (!env->LocalTempFilename(&img_path)) {
|
||||
return port::InternalError(
|
||||
"Could not get temporary filenames for images.");
|
||||
}
|
||||
TF_RETURN_IF_ERROR(tensorflow::WriteStringToFile(
|
||||
env, img_path, std::string(img.bytes.begin(), img.bytes.end())));
|
||||
VLOG(2) << "image written to " << img_path;
|
||||
image_paths.push_back(std::move(img_path));
|
||||
}
|
||||
auto image_files_cleaner = tensorflow::gtl::MakeCleanup([&image_paths] {
|
||||
for (const auto& path : image_paths) {
|
||||
TF_CHECK_OK(tensorflow::Env::Default()->DeleteFile(path));
|
||||
}
|
||||
});
|
||||
|
||||
// Prepare temorary result file.
|
||||
std::string result_path;
|
||||
if (!env->LocalTempFilename(&result_path)) {
|
||||
return port::InternalError(
|
||||
"Could not get temporary filename for fatbin result.");
|
||||
}
|
||||
auto result_file_cleaner = tensorflow::gtl::MakeCleanup([&result_path] {
|
||||
// This file may never be created, so the failure to delete it should not
|
||||
// propagate to TF.
|
||||
tensorflow::Env::Default()->DeleteFile(result_path).IgnoreError();
|
||||
});
|
||||
|
||||
// Invoke fatbinary and collect its output.
|
||||
tensorflow::SubProcess fatbinary;
|
||||
std::vector<std::string> fatbinary_args = {
|
||||
fatbinary_path, "--64", "--cmdline=--compile-only",
|
||||
"--link", "--compress-all", absl::StrCat("--create=", result_path)};
|
||||
assert(images.size() == image_paths.size());
|
||||
for (int i = 0; i < images.size(); i++) {
|
||||
fatbinary_args.push_back(absl::StrFormat(
|
||||
"--image=profile=%s,file=%s", images[i].profile, image_paths[i]));
|
||||
}
|
||||
if (VLOG_IS_ON(3)) {
|
||||
VLOG(3) << absl::StrJoin(fatbinary_args, " ");
|
||||
}
|
||||
fatbinary.SetProgram(fatbinary_path, fatbinary_args);
|
||||
fatbinary.SetChannelAction(tensorflow::CHAN_STDERR, tensorflow::ACTION_PIPE);
|
||||
if (!fatbinary.Start()) {
|
||||
return port::InternalError("Failed to launch fatbinary.");
|
||||
}
|
||||
std::string stderr_output;
|
||||
int exit_status = fatbinary.Communicate(
|
||||
/*stdin_input=*/nullptr, /*stdout_output=*/nullptr, &stderr_output);
|
||||
if (exit_status != 0) {
|
||||
return port::InternalError(absl::StrFormat(
|
||||
"fatbinary exited with non-zero error code %d, output: %s", exit_status,
|
||||
stderr_output));
|
||||
}
|
||||
if (!stderr_output.empty()) {
|
||||
VLOG(2) << stderr_output;
|
||||
}
|
||||
|
||||
// Read in the result and return it as a byte vector.
|
||||
std::string result_blob;
|
||||
TF_RETURN_IF_ERROR(tensorflow::ReadFileToString(tensorflow::Env::Default(),
|
||||
result_path, &result_blob));
|
||||
return std::vector<uint8>(result_blob.begin(), result_blob.end());
|
||||
}
|
||||
|
||||
} // namespace stream_executor
|
||||
|
||||
@ -52,6 +52,16 @@ port::StatusOr<std::vector<uint8>> CompileGpuAsm(int cc_major, int cc_minor,
|
||||
port::StatusOr<absl::Span<const uint8>> CompileGpuAsmOrGetCached(
|
||||
int device_ordinal, const char* ptx, GpuAsmOpts compilation_options);
|
||||
|
||||
struct CubinOrPTXImage {
|
||||
std::string profile;
|
||||
std::vector<uint8> bytes;
|
||||
};
|
||||
|
||||
// Bundles the GPU machine code (cubins) and PTX if requested and returns the
|
||||
// resulting binary (i.e. a fatbin) as a byte array.
|
||||
port::StatusOr<std::vector<uint8>> BundleGpuAsm(
|
||||
std::vector<CubinOrPTXImage> images, const std::string preferred_cuda_dir);
|
||||
|
||||
} // namespace stream_executor
|
||||
|
||||
#endif // TENSORFLOW_STREAM_EXECUTOR_GPU_ASM_COMPILER_H_
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user