cuda:fatbinary dependency causes licensing issues.

PiperOrigin-RevId: 335809522
Change-Id: I374057f1b9b8144420089b04e2a5892f5c64832f
commit 4d77e555e8 (parent 6876c21157)
@@ -74,8 +74,8 @@ tool_names = [
     'tf_tfjs_translate', 'flatbuffer_to_string', 'flatbuffer_translate',
     'tf-mlir-translate', 'mlir-tflite-runner', 'tfcompile',
     'json_to_flatbuffer', 'xla-gpu-opt', 'xla-mlir-gpu-opt', 'xla-opt',
-    'hlo_to_llvm_ir', 'kernel-gen-opt', 'tf_to_gpu_binary', 'xla-thunks-opt',
-    'tfjs-opt'
+    'hlo_to_llvm_ir', 'kernel-gen-opt', 'tf_to_kernel', 'tf_to_gpu_binary',
+    'xla-thunks-opt', 'tfjs-opt'
 ]
 tools = [ToolSubst(s, unresolved='ignore') for s in tool_names]
 llvm_config.add_tool_substitutions(tools, tool_dirs)
@@ -105,7 +105,10 @@ tf_cc_binary(
 tf_cc_binary(
     name = "tf_to_kernel",
     srcs = ["tf_to_kernel.cc"],
-    visibility = ["//tensorflow/core/kernels/mlir_generated:__pkg__"],
+    visibility = [
+        "//tensorflow/compiler/mlir/tools/kernel_gen/tests/tf_to_kernel:__pkg__",
+        "//tensorflow/core/kernels/mlir_generated:__pkg__",
+    ],
     deps = [
         ":kernel_creator",
         "//tensorflow/compiler/mlir:init_mlir",
@@ -174,7 +174,8 @@ Status LowerTFtoGPU(mlir::ModuleOp module, bool gpu_binary_only,
 Status LowerGPUToLLVM(mlir::ModuleOp module, bool gpu_binary_only,
                       llvm::ArrayRef<uint32_t> same_shape,
                       llvm::StringRef gpu_binary_attr_name,
-                      int32_t architecture) {
+                      llvm::ArrayRef<uint32_t> architectures,
+                      bool generate_fatbin) {
   mlir::PassManager pm(module.getContext());
   applyTensorflowAndCLOptions(pm);
 
@@ -187,7 +188,7 @@ Status LowerGPUToLLVM(mlir::ModuleOp module, bool gpu_binary_only,
   }
   kernel_pm.addPass(mlir::createStripDebugInfoPass());
   kernel_pm.addPass(mlir::kernel_gen::transforms::CreateGpuKernelToBlobPass(
-      gpu_binary_attr_name, architecture));
+      gpu_binary_attr_name, architectures, generate_fatbin));
 
   if (!gpu_binary_only) {
     pm.addPass(mlir::kernel_gen::transforms::CreateTFKernelToLLVMPass());
@@ -202,9 +203,9 @@ Status LowerGPUToLLVM(mlir::ModuleOp module, bool gpu_binary_only,
 
 StatusOr<mlir::OwningModuleRef> GenerateKernelForTfCode(
     mlir::MLIRContext& context, llvm::StringRef tf_code, bool gpu_binary_only,
-    int32_t architecture, llvm::ArrayRef<uint32_t> tile_sizes,
+    llvm::ArrayRef<uint32_t> architectures, llvm::ArrayRef<uint32_t> tile_sizes,
     llvm::ArrayRef<uint32_t> same_shape,
-    llvm::ArrayRef<uint32_t> unroll_factors) {
+    llvm::ArrayRef<uint32_t> unroll_factors, bool generate_fatbin) {
   mlir::RegisterAllTensorFlowDialects(context.getDialectRegistry());
   mlir::OwningModuleRef module = mlir::parseSourceString(tf_code, &context);
   TF_RETURN_IF_ERROR(
@@ -221,7 +222,8 @@ StatusOr<mlir::OwningModuleRef> GenerateKernelForTfCode(
   TF_RETURN_IF_ERROR(xla::mlir_gpu::LowerKernelBodiesToNVVM(module.get()));
 #endif
   TF_RETURN_IF_ERROR(LowerGPUToLLVM(module.get(), gpu_binary_only, same_shape,
-                                    kGpuBinaryAttrName, architecture));
+                                    kGpuBinaryAttrName, architectures,
+                                    generate_fatbin));
   return module;
 }
 
@@ -38,9 +38,10 @@ namespace kernel_gen {
 // false, lowers the host side to LLVM Dialect.
 xla::StatusOr<mlir::OwningModuleRef> GenerateKernelForTfCode(
     mlir::MLIRContext& context, llvm::StringRef tf_code, bool gpu_binary_only,
-    int32_t architecture = 75, llvm::ArrayRef<uint32_t> tile_sizes = {16, 64},
+    llvm::ArrayRef<uint32_t> architectures = {75},
+    llvm::ArrayRef<uint32_t> tile_sizes = {16, 64},
     llvm::ArrayRef<uint32_t> same_shape = {},
-    llvm::ArrayRef<uint32_t> unroll_factors = {});
+    llvm::ArrayRef<uint32_t> unroll_factors = {}, bool generate_fatbin = true);
 
 // Extracts gpu_binary from the converted module.
 xla::StatusOr<std::string> ExtractGpuBinary(mlir::ModuleOp module);
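Taken together, the kernel_creator changes make multi-architecture compilation the default path. A minimal caller-side sketch (not from this commit; the context and tf_code setup are assumed):

    // Compile one TF module for sm_70 and sm_75 and bundle the result into a
    // fatbin, using the signature introduced above.
    xla::StatusOr<mlir::OwningModuleRef> CompileForVoltaAndTuring(
        mlir::MLIRContext& context, llvm::StringRef tf_code) {
      return tensorflow::kernel_gen::GenerateKernelForTfCode(
          context, tf_code, /*gpu_binary_only=*/false,
          /*architectures=*/{70, 75}, /*tile_sizes=*/{16, 64},
          /*same_shape=*/{}, /*unroll_factors=*/{4},
          /*generate_fatbin=*/true);
    }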
@@ -1,6 +1,5 @@
 // RUN: tf_to_gpu_binary --input=%s --output=%t --same_shape=0,1 --unroll_factors=4 --tile_sizes=256 --arch=70
 func @tanh(%arg0: tensor<?xf32>) -> tensor<?xf32> {
-  %0 = "tf.Tanh"(%arg0) { }
-       : (tensor<?xf32>) -> tensor<?xf32>
+  %0 = "tf.Tanh"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
   return %0 : tensor<?xf32>
 }
@@ -0,0 +1,17 @@
+load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests")
+
+package(licenses = ["notice"])
+
+glob_lit_tests(
+    data = [
+        "//tensorflow/compiler/mlir/tools/kernel_gen:tf_to_kernel",
+        "@llvm-project//mlir:run_lit.sh",
+    ],
+    default_tags = [
+        # We need access to the CUDA SDK.
+        "gpu",
+        "no_rocm",
+    ],
+    driver = "//tensorflow/compiler/mlir:run_lit.sh",
+    test_file_exts = ["mlir"],
+)
@@ -0,0 +1,6 @@
+// RUN: tf_to_kernel --input=%s --output=%t --same_shape=0,1 --unroll_factors=4 --tile_sizes=256 --arch=70,75
+
+func @tanh(%arg: tensor<*xf32>) -> tensor<*xf32> {
+  %0 = "tf.Tanh"(%arg) : (tensor<*xf32>) -> tensor<*xf32>
+  return %0 : tensor<*xf32>
+}
@@ -48,7 +48,7 @@ xla::Status Run(llvm::StringRef input_file, llvm::StringRef output_file,
       mlir::OwningModuleRef module,
       GenerateKernelForTfCode(context, tf_code, /*gpu_binary_only=*/true,
                               architecture, tile_sizes, same_shape,
-                              unroll_factors));
+                              unroll_factors, /*generate_fatbin=*/false));
   // Extract gpu_binary.
   TF_ASSIGN_OR_RETURN(std::string gpu_binary, ExtractGpuBinary(*module));
 
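tf_to_gpu_binary intentionally stays on the single-architecture, raw-binary path. For reference, the combinations permitted by the validation added to GetGpuBinaryBlob later in this diff (CUDA path; on ROCm any fatbin request is rejected for now):

    // generate_fatbin=false, arch={75}     -> raw machine code (this tool)
    // generate_fatbin=false, arch={70,75}  -> InternalError
    // generate_fatbin=true,  arch={70,75}  -> fatbin bundling all images
    // any value,             arch={}       -> InternalError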
@@ -95,7 +95,8 @@ xla::StatusOr<std::string> EmitToBinary(mlir::ModuleOp module) {
 }
 
 xla::Status Run(llvm::StringRef input_file, llvm::StringRef output_file,
-                int32_t architecture, llvm::ArrayRef<uint32_t> tile_sizes,
+                llvm::ArrayRef<uint32_t> architectures,
+                llvm::ArrayRef<uint32_t> tile_sizes,
                 llvm::ArrayRef<uint32_t> same_shape,
                 llvm::ArrayRef<uint32_t> unroll_factors) {
   // Read TF code.
@@ -107,7 +108,7 @@ xla::Status Run(llvm::StringRef input_file, llvm::StringRef output_file,
   TF_ASSIGN_OR_RETURN(
       mlir::OwningModuleRef module,
       GenerateKernelForTfCode(context, tf_code, /*gpu_binary_only=*/false,
-                              architecture, tile_sizes, same_shape,
+                              architectures, tile_sizes, same_shape,
                               unroll_factors));
   // Get binary.
   TF_ASSIGN_OR_RETURN(std::string binary, EmitToBinary(*module));
@@ -129,8 +130,8 @@ int main(int argc, char** argv) {
   llvm::cl::opt<std::string> output_file(
       "output", llvm::cl::desc("output file"), llvm::cl::value_desc("filename"),
       llvm::cl::init("foo.bin"));
-  llvm::cl::list<int32_t> architecture(
-      "arch", llvm::cl::desc("target architecture (e.g. 50 for sm_50)"),
+  llvm::cl::list<uint32_t> architectures(
+      "arch", llvm::cl::desc("target architectures (e.g. 50 for sm_50)"),
       llvm::cl::OneOrMore, llvm::cl::CommaSeparated);
   llvm::cl::list<uint32_t> tile_sizes(
       "tile_sizes", llvm::cl::desc("tile sizes to use"), llvm::cl::ZeroOrMore,
@@ -151,7 +152,7 @@ int main(int argc, char** argv) {
   llvm::cl::ParseCommandLineOptions(argc, argv, "TF op GPU kernel generator\n");
 
   auto status =
-      tensorflow::kernel_gen::Run(input_file, output_file, architecture.front(),
+      tensorflow::kernel_gen::Run(input_file, output_file, architectures,
                                   tile_sizes, same_shape, unroll_factors);
   if (!status.ok()) {
     LOG(ERROR) << status;
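With llvm::cl::CommaSeparated, one --arch flag expands into multiple list entries, which is why the new lit test's --arch=70,75 yields architectures == {70, 75}. A standalone sketch of the same flag declaration, assuming only LLVM's CommandLine library (names are illustrative):

    #include <cstdint>

    #include "llvm/Support/CommandLine.h"
    #include "llvm/Support/raw_ostream.h"

    static llvm::cl::list<uint32_t> architectures(
        "arch", llvm::cl::desc("target architectures (e.g. 50 for sm_50)"),
        llvm::cl::OneOrMore, llvm::cl::CommaSeparated);

    int main(int argc, char** argv) {
      llvm::cl::ParseCommandLineOptions(argc, argv, "cl::list demo\n");
      // `demo --arch=70,75` and `demo --arch=70 --arch=75` both yield {70, 75}.
      for (uint32_t arch : architectures) llvm::outs() << "sm_" << arch << "\n";
      return 0;
    }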
@@ -117,6 +117,7 @@ cc_library(
         "@llvm-project//mlir:AllPassesAndDialects",
         "@llvm-project//mlir:Support",
         "@llvm-project//mlir:Transforms",
+        "@llvm-project//llvm:TransformUtils",
         "//tensorflow/compiler/mlir/hlo",
         "//tensorflow/compiler/mlir/hlo:hlo_legalize_to_lhlo",
         "//tensorflow/compiler/mlir/hlo:lhlo",
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
+#include "llvm/Transforms/Utils/Cloning.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
 #include "mlir/Target/NVVMIR.h"  // from @llvm-project
 #include "mlir/Target/ROCDLIR.h"  // from @llvm-project
@@ -49,9 +50,12 @@ using xla::InternalError;
 class GpuKernelToBlobPass
     : public GpuKernelToBlobPassBase<GpuKernelToBlobPass> {
  public:
-  GpuKernelToBlobPass(mlir::StringRef blob_annotation, int32_t arch) {
+  GpuKernelToBlobPass(mlir::StringRef blob_annotation,
+                      llvm::ArrayRef<uint32_t> architectures,
+                      bool generate_fatbin) {
     blob_annotation_ = blob_annotation.str();
-    arch_ = arch;
+    architectures_ = architectures;
+    generate_fatbin_ = generate_fatbin;
   }
 
   void runOnOperation() override {
@@ -69,7 +73,17 @@ class GpuKernelToBlobPass
 
   xla::StatusOr<std::vector<uint8_t>> GetGpuBinaryBlob(
       mlir::gpu::GPUModuleOp gpu_module) {
+    if (architectures_.empty()) {
+      return InternalError("Expected at least one GPU architecture.");
+    }
+    if (!generate_fatbin_ && architectures_.size() > 1) {
+      return InternalError(
+          "Can only generate machine code for more than one architecture as a "
+          "fatbin.");
+    }
+
     llvm::LLVMContext llvmContext;
 
 #if TENSORFLOW_USE_ROCM
     auto llvmModule = mlir::translateModuleToROCDLIR(gpu_module, llvmContext);
     if (!llvmModule) {
@@ -81,9 +95,14 @@ class GpuKernelToBlobPass
     xla::HloModuleConfig config;
     config.set_debug_options(xla::GetDebugOptionsFromFlags());
 
-    std::string libdevice_dir = tensorflow::RocdlRoot();
-    return xla::gpu::amdgpu::CompileToHsaco(llvmModule.get(), arch_, config,
+    // TODO(b/169066682): Support fatbin on ROCm.
+    if (generate_fatbin_) {
+      return InternalError("Fatbins are not yet supported for ROCm.");
+    }
+
+    uint32_t arch = architectures_.front();
+    std::string libdevice_dir = tensorflow::RocdlRoot();
+    return xla::gpu::amdgpu::CompileToHsaco(llvmModule.get(), arch, config,
                                             libdevice_dir);
 
 #elif GOOGLE_CUDA
@@ -102,19 +121,42 @@ class GpuKernelToBlobPass
       target->Options.AllowFPOpFusion = llvm::FPOpFusion::FPOpFusionMode::Fast;
     };
 
-    int32_t cc_major = arch_ / 10;
-    int32_t cc_minor = arch_ % 10;
+    // Compile and collect requested cubin and PTX images.
+    std::vector<tensorflow::se::CubinOrPTXImage> images;
     TF_ASSIGN_OR_RETURN(std::string libdevice_dir, GetLibdeviceDir(config));
-    TF_ASSIGN_OR_RETURN(
-        std::string ptx,
-        xla::gpu::nvptx::CompileToPtx(llvmModule.get(),
-                                      std::make_pair(cc_major, cc_minor),
-                                      config, libdevice_dir, enable_fusion));
-    VLOG(1) << ptx;
+    auto gpu_asm_opts = xla::gpu::PtxOptsFromConfig(config);
+    for (uint32_t arch : architectures_) {
+      int32_t cc_major = arch / 10;
+      int32_t cc_minor = arch % 10;
+      // Module may be changed by CompileToPtx.
+      auto llvm_module_copy = llvm::CloneModule(*llvmModule);
+      TF_ASSIGN_OR_RETURN(
+          std::string ptx,
+          xla::gpu::nvptx::CompileToPtx(llvm_module_copy.get(),
+                                        std::make_pair(cc_major, cc_minor),
+                                        config, libdevice_dir, enable_fusion));
+      // TODO(b/169066682): If compute_XX profile, collect PTX image here.
+      VLOG(1) << ptx;
+      TF_ASSIGN_OR_RETURN(std::vector<uint8_t> gpu_asm,
+                          tensorflow::se::CompileGpuAsm(
+                              cc_major, cc_minor, ptx.c_str(), gpu_asm_opts));
 
-    return tensorflow::se::CompileGpuAsm(cc_major, cc_minor, ptx.c_str(),
-                                         xla::gpu::PtxOptsFromConfig(config));
+      if (!generate_fatbin_) {
+        // Skip fatbin generation and return the first and only GPU machine
+        // code.
+        return gpu_asm;
+      }
+
+      // Collect cubin image.
+      images.push_back({absl::StrCat("sm_", arch), std::move(gpu_asm)});
+    }
+
+    // TODO(b/169870789): Revisit the use of fatbins.
+    // Bundle cubin and PTX images into a single fatbin.
+    return tensorflow::se::BundleGpuAsm(images,
+                                        gpu_asm_opts.preferred_cuda_dir);
 #endif
 
     return InternalError(
         "Neither TENSORFLOW_USE_ROCM nor GOOGLE_CUDA are defined."
        " Did you specify either --config=rocm or --config=cuda ?");
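The per-architecture loop derives the compute capability by integer arithmetic on the two-digit arch number, and compiles a clone of the LLVM module because CompileToPtx may mutate its input. A worked example of the split, runnable on its own:

    #include <cstdint>
    #include <cstdio>

    int main() {
      // Same split as the loop above performs per requested architecture.
      for (uint32_t arch : {70u, 75u}) {
        int32_t cc_major = arch / 10;  // 7 and 7
        int32_t cc_minor = arch % 10;  // 0 and 5
        std::printf("sm_%d%d\n", cc_major, cc_minor);  // sm_70, sm_75
      }
      return 0;
    }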
@@ -141,8 +183,10 @@ class GpuKernelToBlobPass
 }  // namespace
 
 std::unique_ptr<OperationPass<gpu::GPUModuleOp>> CreateGpuKernelToBlobPass(
-    mlir::StringRef blob_annotation, int32_t architecture) {
-  return std::make_unique<GpuKernelToBlobPass>(blob_annotation, architecture);
+    mlir::StringRef blob_annotation, ArrayRef<uint32_t> architectures,
+    bool generate_fatbin) {
+  return std::make_unique<GpuKernelToBlobPass>(blob_annotation, architectures,
+                                               generate_fatbin);
 }
 
 }  // namespace transforms
@@ -61,7 +61,8 @@ CreatePropagateTensorFlowABIKnowledgePass(
 
 // Pass to annotate GPU Module with its PTX.
 std::unique_ptr<OperationPass<gpu::GPUModuleOp>> CreateGpuKernelToBlobPass(
-    mlir::StringRef blob_annotation = "", int32_t architecture = 0);
+    mlir::StringRef blob_annotation = "", ArrayRef<uint32_t> architectures = {},
+    bool generate_fatbin = true);
 
 // Pass to unfuse batch norm.
 std::unique_ptr<FunctionPass> CreateUnfuseBatchNormPass();
@@ -53,7 +53,10 @@ def GpuKernelToBlobPass : Pass<"gpu-kernel-to-blob", "gpu::GPUModuleOp"> {
   let options = [
     Option<"blob_annotation_", "blob-annotation", "std::string",
            /*default=*/"", "Blob attribute name">,
-    Option<"arch_", "arch", "int32_t", /*default=*/"0", "GPU architecture">,
+    ListOption<"architectures_", "arch", "uint32_t", "GPU architectures">,
+    Option<"generate_fatbin_", "generate-fatbin", "bool", /*default=*/"true",
+           "Bundle machine code for the different architectures in one "
+           "fatbin.">,
   ];
   let constructor = "transforms::CreateGpuKernelToBlobPass()";
 }
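Since the options are declared in tablegen, they stay settable from the textual pass pipeline as well as from C++. A hedged sketch of the C++ side, mirroring the kernel_creator.cc call earlier in this diff (the gpu-module nesting and the "gpu.binary" attribute name are illustrative placeholders; the real caller passes kGpuBinaryAttrName):

    void AddGpuBlobPass(mlir::PassManager& pm) {
      auto& kernel_pm = pm.nest<mlir::gpu::GPUModuleOp>();
      kernel_pm.addPass(mlir::kernel_gen::transforms::CreateGpuKernelToBlobPass(
          /*blob_annotation=*/"gpu.binary", /*architectures=*/{70, 75},
          /*generate_fatbin=*/true));
    }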
@@ -306,9 +306,6 @@ def _gen_unranked_kernel_fatbin_impl(ctx):
         archs_trimmed.append(arch[3:])
     arch_flag = ",".join(archs_trimmed)
 
-    # TODO(b/169066682): Generate Fatbin when lowering GPU module.
-    arch_flag = "75"
-
     filename = "%s.a" % (name)
     gpu_bin = ctx.outputs.output
     ctx.actions.run(
@@ -106,6 +106,10 @@ cc_library(
 # an intermediate target.
 cc_library(name = "ptxas_wrapper")
 
+# Buildozer can not remove dependencies inside select guards, so we have to use
+# an intermediate target.
+cc_library(name = "fatbinary_wrapper")
+
 cc_library(
     name = "cuda_driver",
     srcs = if_cuda_is_configured(["cuda_driver.cc"]),
@@ -251,6 +251,7 @@ cc_library(
     ]) + if_cuda_is_configured([
         "//tensorflow/stream_executor/cuda:cuda_driver",
         "//tensorflow/stream_executor/cuda:ptxas_wrapper",
+        "//tensorflow/stream_executor/cuda:fatbinary_wrapper",
     ]),
 )
 
@@ -140,34 +140,44 @@ port::StatusOr<std::vector<uint8>> CompileGpuAsm(int device_ordinal,
   return CompileGpuAsm(cc_major, cc_minor, ptx_contents, options);
 }
 
-port::StatusOr<std::vector<uint8>> CompileGpuAsm(int cc_major, int cc_minor,
-                                                 const char* ptx_contents,
-                                                 GpuAsmOpts options) {
-  std::string ptxas_path;
-  auto env = tensorflow::Env::Default();
-  std::string ptxas_binary_name = "ptxas";
+static std::string findCudaExecutable(const std::string binary_name,
+                                      const std::string preferred_cuda_dir) {
 #if defined(PLATFORM_WINDOWS)
-  ptxas_binary_name += ".exe";
+  const std::string binary_filename = binary_name + ".exe";
+#else
+  const std::string& binary_filename = binary_name;
 #endif
 
+  // Search in cuda root candidates.
+  auto env = tensorflow::Env::Default();
+  std::string binary_path;
   for (const std::string& cuda_root :
-       tensorflow::CandidateCudaRoots(options.preferred_cuda_dir)) {
-    ptxas_path = tensorflow::io::JoinPath(cuda_root, "bin", ptxas_binary_name);
-    VLOG(2) << "Looking for ptxas at " << ptxas_path;
-    if (env->FileExists(ptxas_path).ok()) {
+       tensorflow::CandidateCudaRoots(preferred_cuda_dir)) {
+    binary_path = tensorflow::io::JoinPath(cuda_root, "bin", binary_filename);
+    VLOG(2) << "Looking for " << binary_filename << " at " << binary_path;
+    if (env->FileExists(binary_path).ok()) {
       break;
     }
   }
-  if (!env->FileExists(ptxas_path).ok()) {
+  if (!env->FileExists(binary_path).ok()) {
     // Rely on subprocess invocation to find the correct binary.
-    ptxas_path = ptxas_binary_name;
+    binary_path = binary_filename;
   }
-  VLOG(2) << "Using ptxas at " << ptxas_path;
+  VLOG(2) << "Using " << binary_filename << " at " << binary_path;
+  return binary_path;
+}
+
+port::StatusOr<std::vector<uint8>> CompileGpuAsm(int cc_major, int cc_minor,
+                                                 const char* ptx_contents,
+                                                 GpuAsmOpts options) {
+  std::string ptxas_path =
+      findCudaExecutable("ptxas", options.preferred_cuda_dir);
 
   WarnIfBadPtxasVersion(ptxas_path);
 
   // Write ptx into a temporary file.
   std::string ptx_path;
+  auto env = tensorflow::Env::Default();
   if (!env->LocalTempFilename(&ptx_path)) {
     return port::InternalError("couldn't get temp PTX file name");
   }
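The refactored lookup is what lets BundleGpuAsm (next hunk) locate fatbinary with the same probing logic as ptxas: each candidate CUDA root is tried for bin/<name>, and if nothing matches the bare name is returned so subprocess invocation can resolve it via PATH. Usage sketch, assuming the helper stays file-local:

    // Both lookups share the options' preferred CUDA directory.
    std::string ptxas_path =
        findCudaExecutable("ptxas", options.preferred_cuda_dir);
    std::string fatbinary_path =
        findCudaExecutable("fatbinary", options.preferred_cuda_dir);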
@@ -232,4 +242,78 @@ port::StatusOr<std::vector<uint8>> CompileGpuAsm(int cc_major, int cc_minor,
   return cubin_vector;
 }
 
+port::StatusOr<std::vector<uint8>> BundleGpuAsm(
+    std::vector<CubinOrPTXImage> images, const std::string preferred_cuda_dir) {
+  std::string fatbinary_path =
+      findCudaExecutable("fatbinary", preferred_cuda_dir);
+
+  // Write images to temporary files.
+  std::vector<std::string> image_paths;
+  auto env = tensorflow::Env::Default();
+  for (const CubinOrPTXImage& img : images) {
+    std::string img_path;
+    if (!env->LocalTempFilename(&img_path)) {
+      return port::InternalError(
+          "Could not get temporary filenames for images.");
+    }
+    TF_RETURN_IF_ERROR(tensorflow::WriteStringToFile(
+        env, img_path, std::string(img.bytes.begin(), img.bytes.end())));
+    VLOG(2) << "image written to " << img_path;
+    image_paths.push_back(std::move(img_path));
+  }
+  auto image_files_cleaner = tensorflow::gtl::MakeCleanup([&image_paths] {
+    for (const auto& path : image_paths) {
+      TF_CHECK_OK(tensorflow::Env::Default()->DeleteFile(path));
+    }
+  });
+
+  // Prepare temporary result file.
+  std::string result_path;
+  if (!env->LocalTempFilename(&result_path)) {
+    return port::InternalError(
+        "Could not get temporary filename for fatbin result.");
+  }
+  auto result_file_cleaner = tensorflow::gtl::MakeCleanup([&result_path] {
+    // This file may never be created, so the failure to delete it should not
+    // propagate to TF.
+    tensorflow::Env::Default()->DeleteFile(result_path).IgnoreError();
+  });
+
+  // Invoke fatbinary and collect its output.
+  tensorflow::SubProcess fatbinary;
+  std::vector<std::string> fatbinary_args = {
+      fatbinary_path, "--64", "--cmdline=--compile-only",
+      "--link", "--compress-all", absl::StrCat("--create=", result_path)};
+  assert(images.size() == image_paths.size());
+  for (int i = 0; i < images.size(); i++) {
+    fatbinary_args.push_back(absl::StrFormat(
+        "--image=profile=%s,file=%s", images[i].profile, image_paths[i]));
+  }
+  if (VLOG_IS_ON(3)) {
+    VLOG(3) << absl::StrJoin(fatbinary_args, " ");
+  }
+  fatbinary.SetProgram(fatbinary_path, fatbinary_args);
+  fatbinary.SetChannelAction(tensorflow::CHAN_STDERR, tensorflow::ACTION_PIPE);
+  if (!fatbinary.Start()) {
+    return port::InternalError("Failed to launch fatbinary.");
+  }
+  std::string stderr_output;
+  int exit_status = fatbinary.Communicate(
+      /*stdin_input=*/nullptr, /*stdout_output=*/nullptr, &stderr_output);
+  if (exit_status != 0) {
+    return port::InternalError(absl::StrFormat(
+        "fatbinary exited with non-zero error code %d, output: %s", exit_status,
+        stderr_output));
+  }
+  if (!stderr_output.empty()) {
+    VLOG(2) << stderr_output;
+  }
+
+  // Read in the result and return it as a byte vector.
+  std::string result_blob;
+  TF_RETURN_IF_ERROR(tensorflow::ReadFileToString(tensorflow::Env::Default(),
+                                                  result_path, &result_blob));
+  return std::vector<uint8>(result_blob.begin(), result_blob.end());
+}
+
 }  // namespace stream_executor
@@ -52,6 +52,16 @@ port::StatusOr<std::vector<uint8>> CompileGpuAsm(int cc_major, int cc_minor,
 port::StatusOr<absl::Span<const uint8>> CompileGpuAsmOrGetCached(
     int device_ordinal, const char* ptx, GpuAsmOpts compilation_options);
 
+struct CubinOrPTXImage {
+  std::string profile;
+  std::vector<uint8> bytes;
+};
+
+// Bundles the GPU machine code (cubins) and PTX if requested and returns the
+// resulting binary (i.e. a fatbin) as a byte array.
+port::StatusOr<std::vector<uint8>> BundleGpuAsm(
+    std::vector<CubinOrPTXImage> images, const std::string preferred_cuda_dir);
+
 }  // namespace stream_executor
 
 #endif  // TENSORFLOW_STREAM_EXECUTOR_GPU_ASM_COMPILER_H_
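A caller-side sketch of the new public API, assuming two already-compiled cubins (the function name, byte vectors, and CUDA path are placeholders):

    // Hedged sketch; `uint8` is the alias used by asm_compiler.h.
    namespace stream_executor {

    port::StatusOr<std::vector<uint8>> MakeFatbin(
        std::vector<uint8> cubin_sm70, std::vector<uint8> cubin_sm75,
        const std::string& preferred_cuda_dir) {
      std::vector<CubinOrPTXImage> images = {
          {"sm_70", std::move(cubin_sm70)}, {"sm_75", std::move(cubin_sm75)}};
      // Returns fatbin bytes, or an error if fatbinary fails or is not found.
      return BundleGpuAsm(std::move(images), preferred_cuda_dir);
    }

    }  // namespace stream_executor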