From 6bc075083ba07d5b1c92527230366e5a6d685490 Mon Sep 17 00:00:00 2001 From: Adrian Kuegel Date: Thu, 23 Apr 2020 01:31:33 -0700 Subject: [PATCH] Split MlirCompiler into two classes. The base class is used as a way to interface with the MlirCompiler. The implementation class hides the cuda specific dependencies. With this change, we don't need to put xla-opt-main into a if_cuda_is_configured section. PiperOrigin-RevId: 308001425 Change-Id: Id59fe7ceaed5c03efef96608cdeb8e455a257a26 --- tensorflow/compiler/xla/service/BUILD | 2 +- .../compiler/xla/service/mlir_gpu/BUILD | 29 +- .../xla/service/mlir_gpu/mlir_compiler.cc | 531 +--------------- .../xla/service/mlir_gpu/mlir_compiler.h | 30 +- .../service/mlir_gpu/mlir_compiler_impl.cc | 584 ++++++++++++++++++ 5 files changed, 611 insertions(+), 565 deletions(-) create mode 100644 tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler_impl.cc diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index aac16fb723f..aef215e23e8 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -977,7 +977,7 @@ cc_library( "//tensorflow/compiler/xla/service/gpu:gpu_transfer_manager", "//tensorflow/core:stream_executor_no_cuda", ] + if_cuda_is_configured([ - "//tensorflow/compiler/xla/service/mlir_gpu:mlir_compiler", + "//tensorflow/compiler/xla/service/mlir_gpu:mlir_compiler_impl", ]), ) diff --git a/tensorflow/compiler/xla/service/mlir_gpu/BUILD b/tensorflow/compiler/xla/service/mlir_gpu/BUILD index 82016be79a9..cd679f7412e 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/BUILD +++ b/tensorflow/compiler/xla/service/mlir_gpu/BUILD @@ -59,11 +59,26 @@ cc_library( cc_library( name = "mlir_compiler", - srcs = if_cuda_is_configured(["mlir_compiler.cc"]), - hdrs = if_cuda_is_configured(["mlir_compiler.h"]), - deps = if_cuda_is_configured([ + srcs = ["mlir_compiler.cc"], + hdrs = ["mlir_compiler.h"], + deps = [ ":emission_context", + "//tensorflow/compiler/xla/service:compiler", + "//tensorflow/compiler/xla/service/gpu:target_constants", + "//tensorflow/core:stream_executor_no_cuda", + "@llvm-project//llvm:core", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:LLVMDialect", + ], +) + +cc_library( + name = "mlir_compiler_impl", + srcs = if_cuda_is_configured(["mlir_compiler_impl.cc"]), + deps = if_cuda_is_configured([ + ":mlir_compiler", ":failover_compiler", + ":emission_context", ":kernel_lowering", ":lhlo_dialect_emitter", "@com_google_absl//absl/container:flat_hash_map", @@ -77,7 +92,6 @@ cc_library( "@llvm-project//mlir:TargetNVVMIR", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/service:buffer_assignment", - "//tensorflow/compiler/xla/service:compiler", "//tensorflow/compiler/xla/service:dump", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service/gpu:gpu_constants", @@ -93,7 +107,6 @@ cc_library( "//tensorflow/compiler/xla/service/gpu/llvm_gpu_backend", "//tensorflow/core:cuda_libdevice_path", "//tensorflow/core:lib", - "//tensorflow/stream_executor:stream_executor_headers", "//tensorflow/stream_executor/gpu:asm_compiler", ]), alwayslink = True, # Contains compiler registration @@ -186,8 +199,8 @@ cc_library( cc_library( name = "xla_gpu_opt_lib", testonly = True, - srcs = if_cuda_is_configured(["xla_gpu_opt.cc"]), - hdrs = if_cuda_is_configured(["xla_gpu_opt.h"]), + srcs = ["xla_gpu_opt.cc"], + hdrs = ["xla_gpu_opt.h"], tags = ["no_pip"], deps = [ ":failover_compiler", @@ -212,7 +225,7 @@ cc_library( tf_cc_binary( name = "xla-gpu-opt", 
testonly = True, - srcs = if_cuda_is_configured(["xla_gpu_opt_main.cc"]), + srcs = ["xla_gpu_opt_main.cc"], tags = ["no_pip"], deps = [ ":mlir_compiler", diff --git a/tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler.cc b/tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler.cc index dc33be5341c..458522f89e6 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler.cc +++ b/tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler.cc @@ -17,69 +17,18 @@ limitations under the License. #include -#include "absl/container/flat_hash_map.h" -#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" // from @llvm-project -#include "mlir/Dialect/GPU/GPUDialect.h" // from @llvm-project +#include "llvm/IR/Module.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" // from @llvm-project -#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project -#include "mlir/IR/Attributes.h" // from @llvm-project -#include "mlir/IR/Function.h" // from @llvm-project -#include "mlir/IR/Location.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project -#include "mlir/IR/Module.h" // from @llvm-project -#include "mlir/IR/OperationSupport.h" // from @llvm-project -#include "mlir/IR/StandardTypes.h" // from @llvm-project -#include "mlir/IR/Value.h" // from @llvm-project -#include "mlir/Support/LLVM.h" // from @llvm-project -#include "mlir/Target/NVVMIR.h" // from @llvm-project -#include "tensorflow/compiler/xla/service/buffer_assignment.h" -#include "tensorflow/compiler/xla/service/dump.h" -#include "tensorflow/compiler/xla/service/gpu/gpu_constants.h" -#include "tensorflow/compiler/xla/service/gpu/gpu_executable.h" -#include "tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.h" -#include "tensorflow/compiler/xla/service/gpu/gpu_types.h" -#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" -#include "tensorflow/compiler/xla/service/gpu/kernel_thunk.h" -#include "tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h" -#include "tensorflow/compiler/xla/service/gpu/nvptx_compiler.h" -#include "tensorflow/compiler/xla/service/gpu/partition_assignment.h" -#include "tensorflow/compiler/xla/service/gpu/stream_assignment.h" -#include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h" #include "tensorflow/compiler/xla/service/gpu/target_constants.h" -#include "tensorflow/compiler/xla/service/gpu/thunk_schedule.h" -#include "tensorflow/compiler/xla/service/hlo_instruction.h" -#include "tensorflow/compiler/xla/service/hlo_opcode.h" -#include "tensorflow/compiler/xla/service/mlir_gpu/emission_context.h" -#include "tensorflow/compiler/xla/service/mlir_gpu/failover_compiler.h" -#include "tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.h" -#include "tensorflow/compiler/xla/service/mlir_gpu/lhlo_dialect_emitter.h" -#include "tensorflow/compiler/xla/util.h" -#include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/io/path.h" -#include "tensorflow/core/platform/cuda_libdevice_path.h" -#include "tensorflow/stream_executor/gpu/asm_compiler.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" namespace xla { namespace mlir_gpu { namespace { -using ::mlir::BlockArgument; -using ::mlir::dyn_cast; -using ::mlir::FuncOp; using ::mlir::MLIRContext; -using ::mlir::ModuleOp; -using ::mlir::OwningModuleRef; -using ::mlir::UnknownLoc; -using ::mlir::Value; -using ::mlir::gpu::LaunchFuncOp; using ::mlir::LLVM::LLVMDialect; -using ::mlir::LLVM::LLVMFuncOp; -using ::mlir::LLVM::LLVMType; -using ::xla::gpu::GpuExecutable; 
-using ::xla::gpu::GpuHloSchedule; -using ::xla::gpu::GpuVersion; -using ::xla::gpu::StreamAssignment; -using ::xla::gpu::ThunkSchedule; int64 ConfigureLLVMModuleAndGetPointerSize(MLIRContext* context) { LLVMDialect* dialect = context->getRegisteredDialect(); @@ -89,49 +38,6 @@ int64 ConfigureLLVMModuleAndGetPointerSize(MLIRContext* context) { return module.getDataLayout().getPointerSize(); } -// TODO(b/137624192) Share with NVPTX compiler -static std::vector CandidateCudaRoots( - const HloModuleConfig& config) { - return tensorflow::CandidateCudaRoots( - config.debug_options().xla_gpu_cuda_data_dir()); -} - -void PrintCantFindCudaMessage(absl::string_view msg, - const HloModuleConfig& hlo_module_config) { - LOG(WARNING) << msg; - LOG(WARNING) << "Searched for CUDA in the following directories:"; - - for (const auto& dir : CandidateCudaRoots(hlo_module_config)) { - LOG(WARNING) << " " << dir; - } - LOG(WARNING) - << "You can choose the search directory by setting xla_gpu_cuda_data_dir " - "in HloModule's DebugOptions. For most apps, setting the environment " - "variable XLA_FLAGS=--xla_gpu_cuda_data_dir=/path/to/cuda will work."; -} - -// Returns the directory containing nvvm libdevice files. -string GetLibdeviceDir(const HloModuleConfig& hlo_module_config) { - for (const string& cuda_root : CandidateCudaRoots(hlo_module_config)) { - const string libdevice_dir = - tensorflow::io::JoinPath(cuda_root, "nvvm", "libdevice"); - VLOG(2) << "Looking for libdevice at " << libdevice_dir; - if (tensorflow::Env::Default()->IsDirectory(libdevice_dir).ok()) { - VLOG(2) << "Found libdevice dir " << libdevice_dir; - return libdevice_dir; - } - } - PrintCantFindCudaMessage( - "Can't find libdevice directory ${CUDA_DIR}/nvvm/libdevice. This may " - "result in compilation or runtime failures, if the program we try to run " - "uses routines from libdevice.", - hlo_module_config); - - // GetCudaRootCandidates always includes ".", but if everything fails, we - // return it anyway. Better than returning the empty string. - return "."; -} - } // namespace MlirCompiler::MlirCompiler() @@ -141,428 +47,6 @@ se::Platform::Id MlirCompiler::PlatformId() const { return stream_executor::cuda::kCudaPlatformId; } -StatusOr> MlirCompiler::RunHloPasses( - std::unique_ptr module, se::StreamExecutor* stream_exec, - se::DeviceMemoryAllocator* device_allocator) { - // Until we find a reason to do something different, run the same passes - // that the normal GPU backend runs. - gpu::NVPTXCompiler xla_compiler; - TF_RETURN_IF_ERROR(xla_compiler.OptimizeHloModule(module.get(), stream_exec, - device_allocator)); - TF_RETURN_IF_ERROR(xla_compiler.PrepareHloModuleForIrEmitting(module.get())); - - return std::move(module); -} - -namespace { - -// TODO(b/137624192): Move this to custom call handling and share. -absl::optional CanShareBufferHint(const HloInstruction* user, - const HloInstruction* operand, - const ShapeIndex& user_index) { - if (user->opcode() == HloOpcode::kCustomCall) { - // Share the bias buffer with the parent instruction. - if (user->custom_call_target() == xla::gpu::kGemmCallTarget) { - if (user->operand_count() == 3 && user->operand(2) == operand) { - return true; - } - } - // The operand of cholesky can be shared with the first output. - if (user->custom_call_target() == xla::gpu::kCusolverCholeskyCallTarget) { - return user_index.size() == 1 && user_index[0] == 0; - } - } - return absl::nullopt; -} - -// TODO(b/137624192): Share this with nvptx backend. 
-GpuVersion GetGpuVersion(se::StreamExecutor* stream_exec) { - int cc_major, cc_minor; - const auto& device_description = stream_exec->GetDeviceDescription(); - if (!device_description.cuda_compute_capability(&cc_major, &cc_minor)) { - LOG(WARNING) - << "Couldn't get compute capability for device; assuming sm_20."; - cc_major = 2; - cc_minor = 0; - } - return std::make_pair(cc_major, cc_minor); -} - -// Return the constant launch bound along the "x" dimension in "dim" if all the -// other dimensions are 1. Return nullopt otherwise or when any of the bounds -// is not constant. -static absl::optional getLaunchBound(const mlir::gpu::KernelDim3& dim) { - auto get_constant = [](mlir::Operation* op, - mlir::StringRef name) -> absl::optional { - if (auto constant = llvm::dyn_cast_or_null(op)) { - return constant.value().cast().getInt(); - } - op->emitError() << "bound " << name << " is not constant"; - return absl::nullopt; - }; - auto y_op = dim.y.getDefiningOp(); - auto dim_y = get_constant(y_op, "y"); - if (!dim_y.has_value() || dim_y.value() != 1) { - y_op->emitError() << "bound 'y' is not constant 1"; - return absl::nullopt; - } - auto z_op = dim.z.getDefiningOp(); - auto dim_z = get_constant(z_op, "z"); - if (!dim_z.has_value() || dim_z.value() != 1) { - z_op->emitError() << "bound 'z' is not constant 1"; - return absl::nullopt; - } - return get_constant(dim.x.getDefiningOp(), "x"); -} - -namespace { - -// Indexes of a range of arguments in a GPU function. This is used to keep the -// range of arguments that correspond to a lowered kernel argument of -// (previously) memref type. -struct LaunchFuncArgument { - int kernel_argument_begin; - int kernel_argument_size; -}; - -} // end namespace - -using OperandToValueMap = - absl::flat_hash_map>; - -static StatusOr> ComputeOperandToValueMap( - OperandToValueMap* operand_to_value_map, const HloInstruction* instr, - LaunchFuncOp launchOp, LLVMFuncOp kernel) { - auto operands = instr->operands(); - std::vector ordered_operands; - bool has_failed = false; - // A memref will expand into multiple kernel operands, accumulate their number - // in order to find them later. - int cur_operand_position = 0; - - for (int kernel_index = 0; kernel_index < launchOp.getNumKernelOperands(); - ++kernel_index) { - auto launchop_operand = - launchOp.getKernelOperand(kernel_index).dyn_cast(); - if (!launchop_operand) { - launchOp.emitError("argument to kernel is not a function input"); - has_failed = true; - continue; - } - auto memref_type = - launchop_operand.getType().dyn_cast<::mlir::MemRefType>(); - if (!memref_type) { - launchOp.emitError("only memref-typed arguments are supported"); - has_failed = true; - break; - } - // host_index is the argument position to the surrounding function that - // contains the launch. This index corresponds to HLO operand indices - // by construction. - auto host_index = launchop_operand.getArgNumber(); - // The trailing argument to the outer function are the results. - auto operand = - (host_index < operands.size()) ? operands[host_index] : instr; - if (!operand_to_value_map->count(operand)) { - ordered_operands.push_back(operand); - } - // Associate the HLO operand with the argument values of the kernel - // function. 
- int num_unpacked = - mlir::MemRefDescriptor::getNumUnpackedValues(memref_type); - (*operand_to_value_map)[operand].push_back( - {cur_operand_position, num_unpacked}); - cur_operand_position += num_unpacked; - } - if (has_failed) { - return InternalError("Mapping operands to kernel arguments has failed."); - } - return ordered_operands; -} - -Status InsertBufferLoadPreduleIntoKernel( - LLVMFuncOp kernel, const OperandToValueMap& operand_to_value_map, - const std::vector& ordered_operands, - BufferAssignment* assignment, - const std::vector& buffers) { - mlir::OpBuilder builder(kernel.getBody()); - auto llvm_dialect = kernel.getContext()->getRegisteredDialect(); - auto offset_type = LLVMType::getInt64Ty(llvm_dialect); - auto ptr_type = LLVMType::getInt8PtrTy(llvm_dialect); - auto void_type = LLVMType::getVoidTy(llvm_dialect); - auto loc = kernel.getLoc(); - - auto num_original_args = kernel.getNumArguments(); - std::vector new_arg_types(buffers.size(), ptr_type); - kernel.setAttr(kernel.getTypeAttrName(), - mlir::TypeAttr::get(LLVMType::getFunctionTy( - void_type, new_arg_types, /*isVarArg=*/false))); - std::vector original_args(kernel.args_begin(), kernel.args_end()); - - std::vector as_mlir_types(new_arg_types.begin(), - new_arg_types.end()); - auto new_args = kernel.front().addArguments(as_mlir_types); - std::vector buffer_args(new_args.begin(), new_args.end()); - - for (auto operand : ordered_operands) { - TF_ASSIGN_OR_RETURN(auto slice, - assignment->GetUniqueTopLevelSlice(operand)); - auto buffer = std::find(buffers.begin(), buffers.end(), slice.allocation()); - auto index = buffer - buffers.begin(); - auto offset = builder.create( - loc, offset_type, builder.getI64IntegerAttr(slice.offset())); - auto ptr = buffer_args[index]; - - // Replace uses of function arguments pertaining to memref descriptors with - // values derived from HLO buffers. The instructions inserting these values - // into memref descriptors were already introduced during the lowering phase - // as per MLIR calling convention. - for (auto arg : operand_to_value_map.at(operand)) { - mlir::MemRefDescriptorView original( - mlir::ValueRange(original_args) - .slice(arg.kernel_argument_begin, arg.kernel_argument_size)); - - // Allocated and aligned pointers are the same. - auto casted = builder.create( - loc, original.alignedPtr().getType().cast(), - mlir::ValueRange(ptr)); - original.alignedPtr().replaceAllUsesWith(casted); - original.allocatedPtr().replaceAllUsesWith(casted); - - // Use the offset of the HLO buffer instead of the one expected in the - // function call. - original.offset().replaceAllUsesWith(offset); - - // Fill the shape. - auto shape = operand->shape(); - // Unless the operand is a scalar pointer, also fill shape and strides. - if (shape.dimensions().empty()) { - continue; - } - - // TODO(b/137624192) Pass in the descriptor to allow for dynamic shapes. - assert(shape.IsArray() && shape.is_static()); - for (auto extent : llvm::enumerate(shape.dimensions())) { - auto shape = builder.create( - loc, original.size(extent.index()).getType(), - builder.getI64IntegerAttr(extent.value())); - original.size(extent.index()).replaceAllUsesWith(shape); - } - // Finally, fill the strides. - // TODO(b/137624192): Take assigned layout into account. 
- uint64_t accumulator = 0; - for (int64_t idx = shape.rank() - 1; idx >= 0; --idx) { - if (accumulator == 0) { - accumulator = 1; - } else { - accumulator *= shape.dimensions(idx + 1); - } - auto stride = builder.create( - loc, original.stride(idx).getType(), - builder.getI64IntegerAttr(accumulator)); - original.stride(idx).replaceAllUsesWith(stride); - } - } - } - - // Now we can remove the original arguments, as they should have no more - // users. - for (int i = 0; i < num_original_args; ++i) { - kernel.front().eraseArgument(0); - } - - return Status::OK(); -} - -StatusOr> TransformKernelToXlaThunk( - FuncOp func, const HloInstruction* const instr, ModuleOp kernel_module, - BufferAssignment* assignment) { - // Find the single LaunchFuncOp and compute a mapping from operands of - // the hlo instruction to the corresponding values of the kernel - // function in the target module; - LaunchFuncOp launchOp; - auto walkResult = func.walk([&launchOp](LaunchFuncOp op) { - if (launchOp) { - op.emitError("multiple kernels for single top-level HLO"); - return mlir::WalkResult::interrupt(); - } - launchOp = op; - return mlir::WalkResult::advance(); - }); - if (walkResult.wasInterrupted()) { - return InternalError("Multiple kernels for single top-level HLO"); - } - if (!launchOp) { - // If there was no launchOp, then no kernel was generated, so the lowering - // from the LHLO ops to the GPU dialect is not implemented yet. - return Unimplemented("No kernel was generated."); - } - - auto kernel = kernel_module.lookupSymbol(launchOp.kernel()); - - // Store the assignment of operands to block arguments. Note that an operand - // might be used in multiple argument positions, hence the vector. - OperandToValueMap operand_to_value_map; - TF_ASSIGN_OR_RETURN( - auto ordered_operands, - ComputeOperandToValueMap(&operand_to_value_map, instr, launchOp, kernel)); - - // Get the required buffers to support the inputs. Use a set and vector here - // to keep the order fixed. This is mostly useful for testing. - std::unordered_set buffers_needed; - std::vector buffers; - // TODO(b/137624192) Add support for tuples. - for (auto operand : ordered_operands) { - TF_ASSIGN_OR_RETURN(auto buffer, - assignment->GetUniqueTopLevelSlice(operand)); - if (buffers_needed.insert(buffer.allocation()).second) { - buffers.push_back(buffer.allocation()); - } - } - - // TODO(b/137624192) Add support for temp buffer. - // TODO(b/137624192) Add support for constant buffers. - - // Change the signature to match what the XLA runtime expects from the - // kernel. - TF_RETURN_IF_ERROR(InsertBufferLoadPreduleIntoKernel( - kernel, operand_to_value_map, ordered_operands, assignment, buffers)); - - // Finally, create the thunk and set the launch dimensions. - auto thunk = absl::make_unique( - buffers, kernel.getName().str(), instr, - /*unroll_factor=*/1); - - // Set launch bounds. 
- mlir::gpu::KernelDim3 block = launchOp.getBlockSizeOperandValues(); - mlir::gpu::KernelDim3 grid = launchOp.getGridSizeOperandValues(); - absl::optional num_threads = getLaunchBound(block); - absl::optional num_blocks = getLaunchBound(grid); - if (!num_threads || !num_blocks) { - return Unimplemented("Unsupported launch bounds"); - } - thunk->SetLaunchDimensions(gpu::LaunchDimensions(*num_blocks, *num_threads)); - return std::move(thunk); -} - -} // namespace - -StatusOr> MlirCompiler::RunBackend( - std::unique_ptr module, se::StreamExecutor* stream_exec, - se::DeviceMemoryAllocator* device_allocator) { - // Determine the HLO schedule, which is an ordering of HLO instructions. This - // is used by buffer assignment to enable buffer reuse, and the same ordering - // must also be used to determine the thunk launch schedule. - std::unique_ptr stream_assignment = - xla::gpu::AssignStreams(*module); - TF_ASSIGN_OR_RETURN( - std::unique_ptr hlo_schedule, - GpuHloSchedule::Build(*module, *stream_assignment, pointer_size_)); - - // Run buffer analysis on the HLO graph. This analysis figures out which - // temporary buffers are required to run the computation. - TF_ASSIGN_OR_RETURN(std::unique_ptr buffer_assignment, - BufferAssigner::Run( - module.get(), hlo_schedule->ConsumeHloOrdering(), - BufferSizeBytesFunction(), - /*color_alignment=*/ - [](LogicalBuffer::Color) { - return xla::gpu::kXlaAllocatedBufferAlignBytes; - }, - /*allocate_buffers_for_constants=*/true, - /*colorer=*/BufferAssigner::DefaultColorer(), - /*must_not_live_out=*/{}, &CanShareBufferHint)); - DumpHloModuleIfEnabled(*module, *buffer_assignment, "after_optimizations"); - - EmissionContext emission_context(std::move(module)); - if (error_handler_) { - emission_context.setErrorHandler(error_handler_); - } - - OwningModuleRef mlir_module = - ModuleOp::create(UnknownLoc::get(emission_context.getContext())); - LhloDialectEmitter lhlo_emitter(&emission_context, *buffer_assignment, - stream_exec->platform(), *mlir_module); - - TF_RETURN_IF_ERROR(lhlo_emitter.EmitComputation( - *emission_context.getHloModule()->entry_computation())); - - TF_RETURN_IF_ERROR( - module_hook_.invoke(IRHook::LoweringStage::LHLO, *mlir_module)); - - TF_RETURN_IF_ERROR(LowerLHLOToGPU(*mlir_module)); - - TF_RETURN_IF_ERROR( - module_hook_.invoke(IRHook::LoweringStage::GPU, *mlir_module)); - - TF_RETURN_IF_ERROR(LowerKernelBodiesToNVVM(*mlir_module)); - - TF_RETURN_IF_ERROR( - module_hook_.invoke(IRHook::LoweringStage::LLVM, *mlir_module)); - - TF_ASSIGN_OR_RETURN(OwningModuleRef kernel_module, - ExtractKernelModule(*mlir_module)); - - auto thunk_sequence = lhlo_emitter.ConsumeThunkSequence(); - for (auto entry : lhlo_emitter.InstructionToFunctionMap()) { - TF_ASSIGN_OR_RETURN( - auto thunk, - TransformKernelToXlaThunk(entry.second, entry.first, *kernel_module, - buffer_assignment.get())); - thunk_sequence->push_back(std::move(thunk)); - } - - TF_RETURN_IF_ERROR( - module_hook_.invoke(IRHook::LoweringStage::KERNEL, *kernel_module)); - - auto llvmModule = mlir::translateModuleToNVVMIR(*kernel_module); - - if (!llvmModule) { - return InternalError("Translation to LLVM failed"); - } - - llvmModule->setModuleIdentifier(emission_context.getHloModule()->name()); - // TODO(herhut): Why is this needed and does not come from the template? 
- llvmModule->setDataLayout(gpu::nvptx::kDataLayout); - - const auto& config = emission_context.getHloModule()->config(); - TF_ASSIGN_OR_RETURN( - auto ptx, xla::gpu::nvptx::CompileToPtx(llvmModule.get(), - GetGpuVersion(stream_exec), - config, GetLibdeviceDir(config))); - TF_ASSIGN_OR_RETURN( - auto cubin, se::CompileGpuAsm(stream_exec->device_ordinal(), ptx.c_str(), - gpu::PtxOptsFromConfig(config))); - - auto thunk_schedule = absl::make_unique( - std::move(thunk_sequence), std::move(stream_assignment), - hlo_schedule->ThunkLaunchOrder()); - - if (DumpingEnabledForHloModule(*emission_context.getHloModule())) { - DumpToFileInDirOrStdout(*emission_context.getHloModule(), "", - "thunk_schedule", thunk_schedule->ToString()); - } - - // TODO(b/137624192): Add profiling support. - return {absl::make_unique( - ptx, cubin, GetGpuVersion(stream_exec), std::move(thunk_schedule), - emission_context.releaseHloModule(), std::move(buffer_assignment), - nullptr, nullptr)}; -} - -StatusOr>> MlirCompiler::Compile( - std::unique_ptr module_group, - std::vector> stream_execs, - se::DeviceMemoryAllocator* device_allocator) { - return Unimplemented("Not yet implemented in MLIR compiler"); -} - -StatusOr>> -MlirCompiler::CompileAheadOfTime(std::unique_ptr module_group, - const AotCompilationOptions& options) { - return Unimplemented("Not yet implemented in MLIR compiler"); -} - void MlirCompiler::SetModuleHook(IRHook module_hook) { module_hook_ = module_hook; } @@ -579,14 +63,3 @@ void MlirCompiler::RemoveErrorHandler() { error_handler_ = nullptr; } } // namespace mlir_gpu } // namespace xla - -static bool InitModule() { - xla::Compiler::RegisterCompilerFactory( - stream_executor::cuda::kCudaPlatformId, []() { - return absl::make_unique( - absl::make_unique(), - absl::make_unique()); - }); - return true; -} -static bool module_initialized = InitModule(); diff --git a/tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler.h b/tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler.h index 9aeef12ac28..a7b2f9446fa 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler.h +++ b/tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler.h @@ -16,7 +16,6 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_MLIR_GPU_MLIR_COMPILER_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_MLIR_GPU_MLIR_COMPILER_H_ -#include "absl/container/flat_hash_map.h" #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/IR/Module.h" // from @llvm-project #include "tensorflow/compiler/xla/service/compiler.h" @@ -27,7 +26,8 @@ namespace mlir_gpu { // A Compiler implementation that converts XLAs IR to a matching MLIR dialect, // performs all lowering on the MLIR IR and finally converts MLIR to LLVMIR for -// generation of a think suitable for XLAs runtime. +// generation of a thunk suitable for XLAs runtime. MlirCompilerImpl contains +// the implementation. 
class MlirCompiler : public Compiler { using ErrorHandler = std::function; @@ -37,30 +37,6 @@ class MlirCompiler : public Compiler { se::Platform::Id PlatformId() const override; - StatusOr> RunHloPasses( - std::unique_ptr module, se::StreamExecutor* stream_exec, - se::DeviceMemoryAllocator* device_allocator) override; - - StatusOr> RunBackend( - std::unique_ptr module, se::StreamExecutor* stream_exec, - se::DeviceMemoryAllocator* device_allocator) override; - - StatusOr>> Compile( - std::unique_ptr module_group, - std::vector> stream_execs, - se::DeviceMemoryAllocator* device_allocator) override; - - StatusOr>> - CompileAheadOfTime(std::unique_ptr module_group, - const AotCompilationOptions& options) override; - - HloCostAnalysis::ShapeSizeFunction ShapeSizeBytesFunction() const override { - int64 pointer_size = pointer_size_; - return [pointer_size](const Shape& shape) { - return ShapeUtil::ByteSizeOf(shape, pointer_size); - }; - } - struct IRHook { enum class LoweringStage { LHLO, GPU, LLVM, KERNEL }; @@ -80,7 +56,7 @@ class MlirCompiler : public Compiler { void SetErrorHandler(ErrorHandler error_handler); void RemoveErrorHandler(); - private: + protected: ::mlir::MLIRContext context_; int64 pointer_size_; IRHook module_hook_; diff --git a/tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler_impl.cc b/tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler_impl.cc new file mode 100644 index 00000000000..c258d532f8e --- /dev/null +++ b/tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler_impl.cc @@ -0,0 +1,584 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" // from @llvm-project +#include "mlir/Dialect/GPU/GPUDialect.h" // from @llvm-project +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" // from @llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Function.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/Module.h" // from @llvm-project +#include "mlir/IR/OperationSupport.h" // from @llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Target/NVVMIR.h" // from @llvm-project +#include "tensorflow/compiler/xla/service/buffer_assignment.h" +#include "tensorflow/compiler/xla/service/dump.h" +#include "tensorflow/compiler/xla/service/gpu/gpu_constants.h" +#include "tensorflow/compiler/xla/service/gpu/gpu_executable.h" +#include "tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.h" +#include "tensorflow/compiler/xla/service/gpu/gpu_types.h" +#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" +#include "tensorflow/compiler/xla/service/gpu/kernel_thunk.h" +#include "tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h" +#include "tensorflow/compiler/xla/service/gpu/nvptx_compiler.h" +#include "tensorflow/compiler/xla/service/gpu/partition_assignment.h" +#include "tensorflow/compiler/xla/service/gpu/stream_assignment.h" +#include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h" +#include "tensorflow/compiler/xla/service/gpu/target_constants.h" +#include "tensorflow/compiler/xla/service/gpu/thunk_schedule.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/mlir_gpu/emission_context.h" +#include "tensorflow/compiler/xla/service/mlir_gpu/failover_compiler.h" +#include "tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.h" +#include "tensorflow/compiler/xla/service/mlir_gpu/lhlo_dialect_emitter.h" +#include "tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/platform/cuda_libdevice_path.h" +#include "tensorflow/stream_executor/gpu/asm_compiler.h" + +namespace xla { +namespace mlir_gpu { +namespace { + +using ::mlir::BlockArgument; +using ::mlir::dyn_cast; +using ::mlir::FuncOp; +using ::mlir::ModuleOp; +using ::mlir::OwningModuleRef; +using ::mlir::UnknownLoc; +using ::mlir::Value; +using ::mlir::gpu::LaunchFuncOp; +using ::mlir::LLVM::LLVMDialect; +using ::mlir::LLVM::LLVMFuncOp; +using ::mlir::LLVM::LLVMType; +using ::xla::gpu::GpuExecutable; +using ::xla::gpu::GpuHloSchedule; +using ::xla::gpu::GpuVersion; +using ::xla::gpu::StreamAssignment; +using ::xla::gpu::ThunkSchedule; + +// A Compiler implementation that converts XLAs IR to a matching MLIR dialect, +// performs all lowering on the MLIR IR and finally converts MLIR to LLVMIR for +// generation of a thunk suitable for XLAs runtime. 
+class MlirCompilerImpl : public MlirCompiler { + public: + StatusOr> RunHloPasses( + std::unique_ptr module, se::StreamExecutor* stream_exec, + se::DeviceMemoryAllocator* device_allocator) override; + + StatusOr> RunBackend( + std::unique_ptr module, se::StreamExecutor* stream_exec, + se::DeviceMemoryAllocator* device_allocator) override; + + StatusOr>> Compile( + std::unique_ptr module_group, + std::vector> stream_execs, + se::DeviceMemoryAllocator* device_allocator) override; + + StatusOr>> + CompileAheadOfTime(std::unique_ptr module_group, + const AotCompilationOptions& options) override; + + HloCostAnalysis::ShapeSizeFunction ShapeSizeBytesFunction() const override { + int64 pointer_size = pointer_size_; + return [pointer_size](const Shape& shape) { + return ShapeUtil::ByteSizeOf(shape, pointer_size); + }; + } +}; + +// TODO(b/137624192) Share with NVPTX compiler +static std::vector CandidateCudaRoots( + const HloModuleConfig& config) { + return tensorflow::CandidateCudaRoots( + config.debug_options().xla_gpu_cuda_data_dir()); +} + +void PrintCantFindCudaMessage(absl::string_view msg, + const HloModuleConfig& hlo_module_config) { + LOG(WARNING) << msg; + LOG(WARNING) << "Searched for CUDA in the following directories:"; + + for (const auto& dir : CandidateCudaRoots(hlo_module_config)) { + LOG(WARNING) << " " << dir; + } + LOG(WARNING) + << "You can choose the search directory by setting xla_gpu_cuda_data_dir " + "in HloModule's DebugOptions. For most apps, setting the environment " + "variable XLA_FLAGS=--xla_gpu_cuda_data_dir=/path/to/cuda will work."; +} + +// Returns the directory containing nvvm libdevice files. +std::string GetLibdeviceDir(const HloModuleConfig& hlo_module_config) { + for (const string& cuda_root : CandidateCudaRoots(hlo_module_config)) { + const std::string libdevice_dir = + tensorflow::io::JoinPath(cuda_root, "nvvm", "libdevice"); + VLOG(2) << "Looking for libdevice at " << libdevice_dir; + if (tensorflow::Env::Default()->IsDirectory(libdevice_dir).ok()) { + VLOG(2) << "Found libdevice dir " << libdevice_dir; + return libdevice_dir; + } + } + PrintCantFindCudaMessage( + "Can't find libdevice directory ${CUDA_DIR}/nvvm/libdevice. This may " + "result in compilation or runtime failures, if the program we try to run " + "uses routines from libdevice.", + hlo_module_config); + + // GetCudaRootCandidates always includes ".", but if everything fails, we + // return it anyway. Better than returning the empty string. + return "."; +} + +StatusOr> MlirCompilerImpl::RunHloPasses( + std::unique_ptr module, se::StreamExecutor* stream_exec, + se::DeviceMemoryAllocator* device_allocator) { + // Until we find a reason to do something different, run the same passes + // that the normal GPU backend runs. + gpu::NVPTXCompiler xla_compiler; + TF_RETURN_IF_ERROR(xla_compiler.OptimizeHloModule(module.get(), stream_exec, + device_allocator)); + TF_RETURN_IF_ERROR(xla_compiler.PrepareHloModuleForIrEmitting(module.get())); + + return std::move(module); +} + +// TODO(b/137624192): Move this to custom call handling and share. +absl::optional CanShareBufferHint(const HloInstruction* user, + const HloInstruction* operand, + const ShapeIndex& user_index) { + if (user->opcode() == HloOpcode::kCustomCall) { + // Share the bias buffer with the parent instruction. + if (user->custom_call_target() == xla::gpu::kGemmCallTarget) { + if (user->operand_count() == 3 && user->operand(2) == operand) { + return true; + } + } + // The operand of cholesky can be shared with the first output. 
+ if (user->custom_call_target() == xla::gpu::kCusolverCholeskyCallTarget) { + return user_index.size() == 1 && user_index[0] == 0; + } + } + return absl::nullopt; +} + +// TODO(b/137624192): Share this with nvptx backend. +GpuVersion GetGpuVersion(se::StreamExecutor* stream_exec) { + int cc_major, cc_minor; + const auto& device_description = stream_exec->GetDeviceDescription(); + if (!device_description.cuda_compute_capability(&cc_major, &cc_minor)) { + LOG(WARNING) + << "Couldn't get compute capability for device; assuming sm_20."; + cc_major = 2; + cc_minor = 0; + } + return std::make_pair(cc_major, cc_minor); +} + +// Return the constant launch bound along the "x" dimension in "dim" if all the +// other dimensions are 1. Return nullopt otherwise or when any of the bounds +// is not constant. +static absl::optional getLaunchBound(const mlir::gpu::KernelDim3& dim) { + auto get_constant = [](mlir::Operation* op, + mlir::StringRef name) -> absl::optional { + if (auto constant = llvm::dyn_cast_or_null(op)) { + return constant.value().cast().getInt(); + } + op->emitError() << "bound " << name << " is not constant"; + return absl::nullopt; + }; + auto y_op = dim.y.getDefiningOp(); + auto dim_y = get_constant(y_op, "y"); + if (!dim_y.has_value() || dim_y.value() != 1) { + y_op->emitError() << "bound 'y' is not constant 1"; + return absl::nullopt; + } + auto z_op = dim.z.getDefiningOp(); + auto dim_z = get_constant(z_op, "z"); + if (!dim_z.has_value() || dim_z.value() != 1) { + z_op->emitError() << "bound 'z' is not constant 1"; + return absl::nullopt; + } + return get_constant(dim.x.getDefiningOp(), "x"); +} + +// Indexes of a range of arguments in a GPU function. This is used to keep the +// range of arguments that correspond to a lowered kernel argument of +// (previously) memref type. +struct LaunchFuncArgument { + int kernel_argument_begin; + int kernel_argument_size; +}; + +using OperandToValueMap = + absl::flat_hash_map>; + +static StatusOr> ComputeOperandToValueMap( + OperandToValueMap* operand_to_value_map, const HloInstruction* instr, + LaunchFuncOp launchOp, LLVMFuncOp kernel) { + auto operands = instr->operands(); + std::vector ordered_operands; + bool has_failed = false; + // A memref will expand into multiple kernel operands, accumulate their number + // in order to find them later. + int cur_operand_position = 0; + + for (int kernel_index = 0; kernel_index < launchOp.getNumKernelOperands(); + ++kernel_index) { + auto launchop_operand = + launchOp.getKernelOperand(kernel_index).dyn_cast(); + if (!launchop_operand) { + launchOp.emitError("argument to kernel is not a function input"); + has_failed = true; + continue; + } + auto memref_type = + launchop_operand.getType().dyn_cast<::mlir::MemRefType>(); + if (!memref_type) { + launchOp.emitError("only memref-typed arguments are supported"); + has_failed = true; + break; + } + // host_index is the argument position to the surrounding function that + // contains the launch. This index corresponds to HLO operand indices + // by construction. + auto host_index = launchop_operand.getArgNumber(); + // The trailing argument to the outer function are the results. + auto operand = + (host_index < operands.size()) ? operands[host_index] : instr; + if (!operand_to_value_map->count(operand)) { + ordered_operands.push_back(operand); + } + // Associate the HLO operand with the argument values of the kernel + // function. 
+ int num_unpacked = + mlir::MemRefDescriptor::getNumUnpackedValues(memref_type); + (*operand_to_value_map)[operand].push_back( + {cur_operand_position, num_unpacked}); + cur_operand_position += num_unpacked; + } + if (has_failed) { + return InternalError("Mapping operands to kernel arguments has failed."); + } + return ordered_operands; +} + +Status InsertBufferLoadPreduleIntoKernel( + LLVMFuncOp kernel, const OperandToValueMap& operand_to_value_map, + const std::vector& ordered_operands, + BufferAssignment* assignment, + const std::vector& buffers) { + mlir::OpBuilder builder(kernel.getBody()); + auto llvm_dialect = kernel.getContext()->getRegisteredDialect(); + auto offset_type = LLVMType::getInt64Ty(llvm_dialect); + auto ptr_type = LLVMType::getInt8PtrTy(llvm_dialect); + auto void_type = LLVMType::getVoidTy(llvm_dialect); + auto loc = kernel.getLoc(); + + auto num_original_args = kernel.getNumArguments(); + std::vector new_arg_types(buffers.size(), ptr_type); + kernel.setAttr(kernel.getTypeAttrName(), + mlir::TypeAttr::get(LLVMType::getFunctionTy( + void_type, new_arg_types, /*isVarArg=*/false))); + std::vector original_args(kernel.args_begin(), kernel.args_end()); + + std::vector as_mlir_types(new_arg_types.begin(), + new_arg_types.end()); + auto new_args = kernel.front().addArguments(as_mlir_types); + std::vector buffer_args(new_args.begin(), new_args.end()); + + for (auto operand : ordered_operands) { + TF_ASSIGN_OR_RETURN(auto slice, + assignment->GetUniqueTopLevelSlice(operand)); + auto buffer = std::find(buffers.begin(), buffers.end(), slice.allocation()); + auto index = buffer - buffers.begin(); + auto offset = builder.create( + loc, offset_type, builder.getI64IntegerAttr(slice.offset())); + auto ptr = buffer_args[index]; + + // Replace uses of function arguments pertaining to memref descriptors with + // values derived from HLO buffers. The instructions inserting these values + // into memref descriptors were already introduced during the lowering phase + // as per MLIR calling convention. + for (auto arg : operand_to_value_map.at(operand)) { + mlir::MemRefDescriptorView original( + mlir::ValueRange(original_args) + .slice(arg.kernel_argument_begin, arg.kernel_argument_size)); + + // Allocated and aligned pointers are the same. + auto casted = builder.create( + loc, original.alignedPtr().getType().cast(), + mlir::ValueRange(ptr)); + original.alignedPtr().replaceAllUsesWith(casted); + original.allocatedPtr().replaceAllUsesWith(casted); + + // Use the offset of the HLO buffer instead of the one expected in the + // function call. + original.offset().replaceAllUsesWith(offset); + + // Fill the shape. + auto shape = operand->shape(); + // Unless the operand is a scalar pointer, also fill shape and strides. + if (shape.dimensions().empty()) { + continue; + } + + // TODO(b/137624192) Pass in the descriptor to allow for dynamic shapes. + assert(shape.IsArray() && shape.is_static()); + for (auto extent : llvm::enumerate(shape.dimensions())) { + auto shape = builder.create( + loc, original.size(extent.index()).getType(), + builder.getI64IntegerAttr(extent.value())); + original.size(extent.index()).replaceAllUsesWith(shape); + } + // Finally, fill the strides. + // TODO(b/137624192): Take assigned layout into account. 
+ uint64_t accumulator = 0; + for (int64_t idx = shape.rank() - 1; idx >= 0; --idx) { + if (accumulator == 0) { + accumulator = 1; + } else { + accumulator *= shape.dimensions(idx + 1); + } + auto stride = builder.create( + loc, original.stride(idx).getType(), + builder.getI64IntegerAttr(accumulator)); + original.stride(idx).replaceAllUsesWith(stride); + } + } + } + + // Now we can remove the original arguments, as they should have no more + // users. + for (int i = 0; i < num_original_args; ++i) { + kernel.front().eraseArgument(0); + } + + return Status::OK(); +} + +StatusOr> TransformKernelToXlaThunk( + FuncOp func, const HloInstruction* const instr, ModuleOp kernel_module, + BufferAssignment* assignment) { + // Find the single LaunchFuncOp and compute a mapping from operands of + // the hlo instruction to the corresponding values of the kernel + // function in the target module; + LaunchFuncOp launchOp; + auto walkResult = func.walk([&launchOp](LaunchFuncOp op) { + if (launchOp) { + op.emitError("multiple kernels for single top-level HLO"); + return mlir::WalkResult::interrupt(); + } + launchOp = op; + return mlir::WalkResult::advance(); + }); + if (walkResult.wasInterrupted()) { + return InternalError("Multiple kernels for single top-level HLO"); + } + if (!launchOp) { + // If there was no launchOp, then no kernel was generated, so the lowering + // from the LHLO ops to the GPU dialect is not implemented yet. + return Unimplemented("No kernel was generated."); + } + + auto kernel = kernel_module.lookupSymbol(launchOp.kernel()); + + // Store the assignment of operands to block arguments. Note that an operand + // might be used in multiple argument positions, hence the vector. + OperandToValueMap operand_to_value_map; + TF_ASSIGN_OR_RETURN( + auto ordered_operands, + ComputeOperandToValueMap(&operand_to_value_map, instr, launchOp, kernel)); + + // Get the required buffers to support the inputs. Use a set and vector here + // to keep the order fixed. This is mostly useful for testing. + std::unordered_set buffers_needed; + std::vector buffers; + // TODO(b/137624192) Add support for tuples. + for (auto operand : ordered_operands) { + TF_ASSIGN_OR_RETURN(auto buffer, + assignment->GetUniqueTopLevelSlice(operand)); + if (buffers_needed.insert(buffer.allocation()).second) { + buffers.push_back(buffer.allocation()); + } + } + + // TODO(b/137624192) Add support for temp buffer. + // TODO(b/137624192) Add support for constant buffers. + + // Change the signature to match what the XLA runtime expects from the + // kernel. + TF_RETURN_IF_ERROR(InsertBufferLoadPreduleIntoKernel( + kernel, operand_to_value_map, ordered_operands, assignment, buffers)); + + // Finally, create the thunk and set the launch dimensions. + auto thunk = absl::make_unique( + buffers, kernel.getName().str(), instr, + /*unroll_factor=*/1); + + // Set launch bounds. + mlir::gpu::KernelDim3 block = launchOp.getBlockSizeOperandValues(); + mlir::gpu::KernelDim3 grid = launchOp.getGridSizeOperandValues(); + absl::optional num_threads = getLaunchBound(block); + absl::optional num_blocks = getLaunchBound(grid); + if (!num_threads || !num_blocks) { + return Unimplemented("Unsupported launch bounds"); + } + thunk->SetLaunchDimensions(gpu::LaunchDimensions(*num_blocks, *num_threads)); + return std::move(thunk); +} + +StatusOr> MlirCompilerImpl::RunBackend( + std::unique_ptr module, se::StreamExecutor* stream_exec, + se::DeviceMemoryAllocator* device_allocator) { + // Determine the HLO schedule, which is an ordering of HLO instructions. 
This + // is used by buffer assignment to enable buffer reuse, and the same ordering + // must also be used to determine the thunk launch schedule. + std::unique_ptr stream_assignment = + xla::gpu::AssignStreams(*module); + TF_ASSIGN_OR_RETURN( + std::unique_ptr hlo_schedule, + GpuHloSchedule::Build(*module, *stream_assignment, pointer_size_)); + + // Run buffer analysis on the HLO graph. This analysis figures out which + // temporary buffers are required to run the computation. + TF_ASSIGN_OR_RETURN(std::unique_ptr buffer_assignment, + BufferAssigner::Run( + module.get(), hlo_schedule->ConsumeHloOrdering(), + BufferSizeBytesFunction(), + /*color_alignment=*/ + [](LogicalBuffer::Color) { + return xla::gpu::kXlaAllocatedBufferAlignBytes; + }, + /*allocate_buffers_for_constants=*/true, + /*colorer=*/BufferAssigner::DefaultColorer(), + /*must_not_live_out=*/{}, &CanShareBufferHint)); + DumpHloModuleIfEnabled(*module, *buffer_assignment, "after_optimizations"); + + EmissionContext emission_context(std::move(module)); + if (error_handler_) { + emission_context.setErrorHandler(error_handler_); + } + + OwningModuleRef mlir_module = + ModuleOp::create(UnknownLoc::get(emission_context.getContext())); + LhloDialectEmitter lhlo_emitter(&emission_context, *buffer_assignment, + stream_exec->platform(), *mlir_module); + + TF_RETURN_IF_ERROR(lhlo_emitter.EmitComputation( + *emission_context.getHloModule()->entry_computation())); + + TF_RETURN_IF_ERROR( + module_hook_.invoke(IRHook::LoweringStage::LHLO, *mlir_module)); + + TF_RETURN_IF_ERROR(LowerLHLOToGPU(*mlir_module)); + + TF_RETURN_IF_ERROR( + module_hook_.invoke(IRHook::LoweringStage::GPU, *mlir_module)); + + TF_RETURN_IF_ERROR(LowerKernelBodiesToNVVM(*mlir_module)); + + TF_RETURN_IF_ERROR( + module_hook_.invoke(IRHook::LoweringStage::LLVM, *mlir_module)); + + TF_ASSIGN_OR_RETURN(OwningModuleRef kernel_module, + ExtractKernelModule(*mlir_module)); + + auto thunk_sequence = lhlo_emitter.ConsumeThunkSequence(); + for (auto entry : lhlo_emitter.InstructionToFunctionMap()) { + TF_ASSIGN_OR_RETURN( + auto thunk, + TransformKernelToXlaThunk(entry.second, entry.first, *kernel_module, + buffer_assignment.get())); + thunk_sequence->push_back(std::move(thunk)); + } + + TF_RETURN_IF_ERROR( + module_hook_.invoke(IRHook::LoweringStage::KERNEL, *kernel_module)); + + auto llvmModule = mlir::translateModuleToNVVMIR(*kernel_module); + + if (!llvmModule) { + return InternalError("Translation to LLVM failed"); + } + + llvmModule->setModuleIdentifier(emission_context.getHloModule()->name()); + // TODO(herhut): Why is this needed and does not come from the template? + llvmModule->setDataLayout(gpu::nvptx::kDataLayout); + + const auto& config = emission_context.getHloModule()->config(); + TF_ASSIGN_OR_RETURN( + auto ptx, xla::gpu::nvptx::CompileToPtx(llvmModule.get(), + GetGpuVersion(stream_exec), + config, GetLibdeviceDir(config))); + TF_ASSIGN_OR_RETURN( + auto cubin, se::CompileGpuAsm(stream_exec->device_ordinal(), ptx.c_str(), + gpu::PtxOptsFromConfig(config))); + + auto thunk_schedule = absl::make_unique( + std::move(thunk_sequence), std::move(stream_assignment), + hlo_schedule->ThunkLaunchOrder()); + + if (DumpingEnabledForHloModule(*emission_context.getHloModule())) { + DumpToFileInDirOrStdout(*emission_context.getHloModule(), "", + "thunk_schedule", thunk_schedule->ToString()); + } + + // TODO(b/137624192): Add profiling support. 
+ return {absl::make_unique( + ptx, cubin, GetGpuVersion(stream_exec), std::move(thunk_schedule), + emission_context.releaseHloModule(), std::move(buffer_assignment), + nullptr, nullptr)}; +} + +StatusOr>> MlirCompilerImpl::Compile( + std::unique_ptr module_group, + std::vector> stream_execs, + se::DeviceMemoryAllocator* device_allocator) { + return Unimplemented("Not yet implemented in MLIR compiler"); +} + +StatusOr>> +MlirCompilerImpl::CompileAheadOfTime( + std::unique_ptr /*module_group*/, + const AotCompilationOptions& /*options*/) { + return Unimplemented("Not yet implemented in MLIR compiler"); +} + +} // namespace +} // namespace mlir_gpu +} // namespace xla + +static bool InitModule() { + xla::Compiler::RegisterCompilerFactory( + stream_executor::cuda::kCudaPlatformId, []() { + return absl::make_unique( + absl::make_unique(), + absl::make_unique()); + }); + return true; +} +static bool module_initialized = InitModule();
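---

Note (not part of the patch): the commit message above describes the pattern this change introduces — MlirCompiler stays as a plain base class that callers such as xla-gpu-opt can link against unconditionally, while MlirCompilerImpl carries the CUDA-only lowering code and registers itself through a compiler factory at static-initialization time, so only the impl target needs to sit behind if_cuda_is_configured. The following is a minimal standalone sketch of that pattern, assuming simplified stand-in types (Compiler, CompilerRegistry, MlirCompilerBase, MlirCompilerImplSketch are illustrative names, not the real XLA/TensorFlow APIs); only the shape of the base/impl/registration split mirrors the patch.

// Sketch of the base-class / CUDA-specific-impl / factory-registration split.
#include <functional>
#include <iostream>
#include <map>
#include <memory>
#include <string>

// Stand-in for xla::Compiler.
class Compiler {
 public:
  virtual ~Compiler() = default;
  virtual std::string PlatformId() const = 0;
  virtual std::string RunBackend(const std::string& hlo) = 0;
};

// Stand-in for the factory registration machinery on xla::Compiler.
class CompilerRegistry {
 public:
  using Factory = std::function<std::unique_ptr<Compiler>()>;
  static std::map<std::string, Factory>& Factories() {
    static std::map<std::string, Factory> factories;
    return factories;
  }
  static void Register(const std::string& platform, Factory factory) {
    Factories()[platform] = std::move(factory);
  }
};

// Base class: the interface that tools link against (hooks, platform id),
// with no CUDA-only dependencies. This mirrors mlir_compiler.h after the
// split, where context_, pointer_size_, module_hook_ become protected.
class MlirCompilerBase : public Compiler {
 public:
  using ModuleHook = std::function<void(const std::string& ir)>;
  std::string PlatformId() const override { return "CUDA"; }
  void SetModuleHook(ModuleHook hook) { module_hook_ = std::move(hook); }

 protected:
  ModuleHook module_hook_;  // Visible to the implementation subclass.
};

// Implementation class: in the patch this lives in mlir_compiler_impl.cc and
// is only compiled if_cuda_is_configured, because the real RunBackend pulls
// in NVVM translation, PTX compilation, and cubin assembly.
class MlirCompilerImplSketch : public MlirCompilerBase {
 public:
  std::string RunBackend(const std::string& hlo) override {
    std::string ir = "nvvm(" + hlo + ")";  // Placeholder for real lowering.
    if (module_hook_) module_hook_(ir);
    return ir;
  }
};

// Static registration, analogous to InitModule() at the end of
// mlir_compiler_impl.cc: the factory only exists when the impl is linked in.
static bool registered = [] {
  CompilerRegistry::Register("CUDA", [] {
    return std::make_unique<MlirCompilerImplSketch>();
  });
  return true;
}();

int main() {
  auto compiler = CompilerRegistry::Factories()["CUDA"]();
  std::cout << compiler->RunBackend("hlo_module") << "\n";
}

In this shape, a tool that only needs SetModuleHook and PlatformId depends solely on the base class, which is why the patch can move xla-gpu-opt's sources out of the if_cuda_is_configured wrapper while keeping the CUDA-dependent pieces confined to the registered implementation.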