Split MlirCompiler into two classes.

The base class serves as the interface to MlirCompiler; the implementation class hides the CUDA-specific dependencies. With this change, the xla-gpu-opt targets no longer need to be wrapped in an if_cuda_is_configured section.

PiperOrigin-RevId: 308001425
Change-Id: Id59fe7ceaed5c03efef96608cdeb8e455a257a26
parent cf9d79b432
commit 6bc075083b
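For orientation, the split looks roughly like the sketch below. This is an illustrative, self-contained example with simplified names and signatures (MlirCompilerBase, MlirCompilerImplSketch, a RunBackend that takes no arguments), not the real XLA headers: the base class owns the interface and the CUDA-agnostic shared state, while the implementation class, built only when CUDA is configured, supplies the actual GPU compilation pipeline.

#include <cstdint>
#include <iostream>
#include <memory>

// CUDA-agnostic base: interface plus shared state, safe to link everywhere.
class MlirCompilerBase {
 public:
  virtual ~MlirCompilerBase() = default;
  virtual void RunBackend() = 0;  // the heavy lifting lives in the subclass
  int64_t pointer_size() const { return pointer_size_; }

 protected:
  int64_t pointer_size_ = 8;  // shared state the implementation reuses
};

// CUDA-dependent implementation, built only when CUDA is configured.
class MlirCompilerImplSketch : public MlirCompilerBase {
 public:
  void RunBackend() override {
    // The real pipeline lowers HLO -> LHLO -> GPU dialect -> NVVM -> PTX/cubin.
    std::cout << "compiling with pointer size " << pointer_size_ << "\n";
  }
};

int main() {
  std::unique_ptr<MlirCompilerBase> compiler =
      std::make_unique<MlirCompilerImplSketch>();
  compiler->RunBackend();
  return 0;
}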
@@ -977,7 +977,7 @@ cc_library(
        "//tensorflow/compiler/xla/service/gpu:gpu_transfer_manager",
        "//tensorflow/core:stream_executor_no_cuda",
    ] + if_cuda_is_configured([
        "//tensorflow/compiler/xla/service/mlir_gpu:mlir_compiler",
        "//tensorflow/compiler/xla/service/mlir_gpu:mlir_compiler_impl",
    ]),
)

@@ -59,11 +59,26 @@ cc_library(

cc_library(
    name = "mlir_compiler",
    srcs = if_cuda_is_configured(["mlir_compiler.cc"]),
    hdrs = if_cuda_is_configured(["mlir_compiler.h"]),
    deps = if_cuda_is_configured([
    srcs = ["mlir_compiler.cc"],
    hdrs = ["mlir_compiler.h"],
    deps = [
        ":emission_context",
        "//tensorflow/compiler/xla/service:compiler",
        "//tensorflow/compiler/xla/service/gpu:target_constants",
        "//tensorflow/core:stream_executor_no_cuda",
        "@llvm-project//llvm:core",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:LLVMDialect",
    ],
)

cc_library(
    name = "mlir_compiler_impl",
    srcs = if_cuda_is_configured(["mlir_compiler_impl.cc"]),
    deps = if_cuda_is_configured([
        ":mlir_compiler",
        ":failover_compiler",
        ":emission_context",
        ":kernel_lowering",
        ":lhlo_dialect_emitter",
        "@com_google_absl//absl/container:flat_hash_map",
@@ -77,7 +92,6 @@ cc_library(
        "@llvm-project//mlir:TargetNVVMIR",
        "//tensorflow/compiler/xla:util",
        "//tensorflow/compiler/xla/service:buffer_assignment",
        "//tensorflow/compiler/xla/service:compiler",
        "//tensorflow/compiler/xla/service:dump",
        "//tensorflow/compiler/xla/service:hlo",
        "//tensorflow/compiler/xla/service/gpu:gpu_constants",
@@ -93,7 +107,6 @@ cc_library(
        "//tensorflow/compiler/xla/service/gpu/llvm_gpu_backend",
        "//tensorflow/core:cuda_libdevice_path",
        "//tensorflow/core:lib",
        "//tensorflow/stream_executor:stream_executor_headers",
        "//tensorflow/stream_executor/gpu:asm_compiler",
    ]),
    alwayslink = True,  # Contains compiler registration
@@ -186,8 +199,8 @@ cc_library(
cc_library(
    name = "xla_gpu_opt_lib",
    testonly = True,
    srcs = if_cuda_is_configured(["xla_gpu_opt.cc"]),
    hdrs = if_cuda_is_configured(["xla_gpu_opt.h"]),
    srcs = ["xla_gpu_opt.cc"],
    hdrs = ["xla_gpu_opt.h"],
    tags = ["no_pip"],
    deps = [
        ":failover_compiler",
@@ -212,7 +225,7 @@ cc_library(
tf_cc_binary(
    name = "xla-gpu-opt",
    testonly = True,
    srcs = if_cuda_is_configured(["xla_gpu_opt_main.cc"]),
    srcs = ["xla_gpu_opt_main.cc"],
    tags = ["no_pip"],
    deps = [
        ":mlir_compiler",

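The reason the xla-gpu-opt targets can drop if_cuda_is_configured is the registration pattern visible above: mlir_compiler_impl is alwayslink = True and registers a compiler factory from a static initializer, so the tool only needs the CUDA-free mlir_compiler interface and simply finds no factory when CUDA is not configured. Below is a self-contained sketch of that pattern using a toy registry and hypothetical names (MiniRegistry, ToyCompiler); the real code goes through xla::Compiler::RegisterCompilerFactory.

#include <functional>
#include <iostream>
#include <map>
#include <memory>
#include <string>

struct ToyCompiler {
  virtual ~ToyCompiler() = default;
  virtual std::string Name() const = 0;
};

using ToyFactory = std::function<std::unique_ptr<ToyCompiler>()>;

// Keyed by platform; intentionally leaked to avoid static destruction order issues.
std::map<std::string, ToyFactory>& MiniRegistry() {
  static auto* registry = new std::map<std::string, ToyFactory>();
  return *registry;
}

// This part would live in the CUDA-only target (mlir_compiler_impl).
struct ToyMlirCompilerImpl : ToyCompiler {
  std::string Name() const override { return "MlirCompilerImpl"; }
};

static bool registered = [] {
  MiniRegistry()["cuda"] = [] { return std::make_unique<ToyMlirCompilerImpl>(); };
  return true;
}();

// This part would live in the tool (e.g. xla-gpu-opt), which only links the
// base target: it looks a compiler up by platform and degrades gracefully.
int main() {
  auto it = MiniRegistry().find("cuda");
  if (it == MiniRegistry().end()) {
    std::cout << "no compiler registered for this platform\n";
    return 0;
  }
  std::cout << it->second()->Name() << " registered\n";
  return 0;
}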
@@ -17,69 +17,18 @@ limitations under the License.
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "absl/container/flat_hash_map.h"
|
||||
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" // from @llvm-project
|
||||
#include "mlir/Dialect/GPU/GPUDialect.h" // from @llvm-project
|
||||
#include "llvm/IR/Module.h"
|
||||
#include "mlir/Dialect/LLVMIR/LLVMDialect.h" // from @llvm-project
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
|
||||
#include "mlir/IR/Attributes.h" // from @llvm-project
|
||||
#include "mlir/IR/Function.h" // from @llvm-project
|
||||
#include "mlir/IR/Location.h" // from @llvm-project
|
||||
#include "mlir/IR/MLIRContext.h" // from @llvm-project
|
||||
#include "mlir/IR/Module.h" // from @llvm-project
|
||||
#include "mlir/IR/OperationSupport.h" // from @llvm-project
|
||||
#include "mlir/IR/StandardTypes.h" // from @llvm-project
|
||||
#include "mlir/IR/Value.h" // from @llvm-project
|
||||
#include "mlir/Support/LLVM.h" // from @llvm-project
|
||||
#include "mlir/Target/NVVMIR.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
|
||||
#include "tensorflow/compiler/xla/service/dump.h"
|
||||
#include "tensorflow/compiler/xla/service/gpu/gpu_constants.h"
|
||||
#include "tensorflow/compiler/xla/service/gpu/gpu_executable.h"
|
||||
#include "tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.h"
|
||||
#include "tensorflow/compiler/xla/service/gpu/gpu_types.h"
|
||||
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
|
||||
#include "tensorflow/compiler/xla/service/gpu/kernel_thunk.h"
|
||||
#include "tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h"
|
||||
#include "tensorflow/compiler/xla/service/gpu/nvptx_compiler.h"
|
||||
#include "tensorflow/compiler/xla/service/gpu/partition_assignment.h"
|
||||
#include "tensorflow/compiler/xla/service/gpu/stream_assignment.h"
|
||||
#include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h"
|
||||
#include "tensorflow/compiler/xla/service/gpu/target_constants.h"
|
||||
#include "tensorflow/compiler/xla/service/gpu/thunk_schedule.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
|
||||
#include "tensorflow/compiler/xla/service/mlir_gpu/emission_context.h"
|
||||
#include "tensorflow/compiler/xla/service/mlir_gpu/failover_compiler.h"
|
||||
#include "tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.h"
|
||||
#include "tensorflow/compiler/xla/service/mlir_gpu/lhlo_dialect_emitter.h"
|
||||
#include "tensorflow/compiler/xla/util.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/io/path.h"
|
||||
#include "tensorflow/core/platform/cuda_libdevice_path.h"
|
||||
#include "tensorflow/stream_executor/gpu/asm_compiler.h"
|
||||
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
|
||||
|
||||
namespace xla {
|
||||
namespace mlir_gpu {
|
||||
namespace {
|
||||
|
||||
using ::mlir::BlockArgument;
|
||||
using ::mlir::dyn_cast;
|
||||
using ::mlir::FuncOp;
|
||||
using ::mlir::MLIRContext;
|
||||
using ::mlir::ModuleOp;
|
||||
using ::mlir::OwningModuleRef;
|
||||
using ::mlir::UnknownLoc;
|
||||
using ::mlir::Value;
|
||||
using ::mlir::gpu::LaunchFuncOp;
|
||||
using ::mlir::LLVM::LLVMDialect;
|
||||
using ::mlir::LLVM::LLVMFuncOp;
|
||||
using ::mlir::LLVM::LLVMType;
|
||||
using ::xla::gpu::GpuExecutable;
|
||||
using ::xla::gpu::GpuHloSchedule;
|
||||
using ::xla::gpu::GpuVersion;
|
||||
using ::xla::gpu::StreamAssignment;
|
||||
using ::xla::gpu::ThunkSchedule;
|
||||
|
||||
int64 ConfigureLLVMModuleAndGetPointerSize(MLIRContext* context) {
|
||||
LLVMDialect* dialect = context->getRegisteredDialect<LLVMDialect>();
|
||||
@@ -89,49 +38,6 @@ int64 ConfigureLLVMModuleAndGetPointerSize(MLIRContext* context) {
|
||||
return module.getDataLayout().getPointerSize();
|
||||
}
|
||||
|
||||
// TODO(b/137624192) Share with NVPTX compiler
|
||||
static std::vector<std::string> CandidateCudaRoots(
|
||||
const HloModuleConfig& config) {
|
||||
return tensorflow::CandidateCudaRoots(
|
||||
config.debug_options().xla_gpu_cuda_data_dir());
|
||||
}
|
||||
|
||||
void PrintCantFindCudaMessage(absl::string_view msg,
|
||||
const HloModuleConfig& hlo_module_config) {
|
||||
LOG(WARNING) << msg;
|
||||
LOG(WARNING) << "Searched for CUDA in the following directories:";
|
||||
|
||||
for (const auto& dir : CandidateCudaRoots(hlo_module_config)) {
|
||||
LOG(WARNING) << " " << dir;
|
||||
}
|
||||
LOG(WARNING)
|
||||
<< "You can choose the search directory by setting xla_gpu_cuda_data_dir "
|
||||
"in HloModule's DebugOptions. For most apps, setting the environment "
|
||||
"variable XLA_FLAGS=--xla_gpu_cuda_data_dir=/path/to/cuda will work.";
|
||||
}
|
||||
|
||||
// Returns the directory containing nvvm libdevice files.
|
||||
string GetLibdeviceDir(const HloModuleConfig& hlo_module_config) {
|
||||
for (const string& cuda_root : CandidateCudaRoots(hlo_module_config)) {
|
||||
const string libdevice_dir =
|
||||
tensorflow::io::JoinPath(cuda_root, "nvvm", "libdevice");
|
||||
VLOG(2) << "Looking for libdevice at " << libdevice_dir;
|
||||
if (tensorflow::Env::Default()->IsDirectory(libdevice_dir).ok()) {
|
||||
VLOG(2) << "Found libdevice dir " << libdevice_dir;
|
||||
return libdevice_dir;
|
||||
}
|
||||
}
|
||||
PrintCantFindCudaMessage(
|
||||
"Can't find libdevice directory ${CUDA_DIR}/nvvm/libdevice. This may "
|
||||
"result in compilation or runtime failures, if the program we try to run "
|
||||
"uses routines from libdevice.",
|
||||
hlo_module_config);
|
||||
|
||||
// GetCudaRootCandidates always includes ".", but if everything fails, we
|
||||
// return it anyway. Better than returning the empty string.
|
||||
return ".";
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
MlirCompiler::MlirCompiler()
|
||||
@@ -141,428 +47,6 @@ se::Platform::Id MlirCompiler::PlatformId() const {
|
||||
return stream_executor::cuda::kCudaPlatformId;
|
||||
}
|
||||
|
||||
StatusOr<std::unique_ptr<HloModule>> MlirCompiler::RunHloPasses(
|
||||
std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec,
|
||||
se::DeviceMemoryAllocator* device_allocator) {
|
||||
// Until we find a reason to do something different, run the same passes
|
||||
// that the normal GPU backend runs.
|
||||
gpu::NVPTXCompiler xla_compiler;
|
||||
TF_RETURN_IF_ERROR(xla_compiler.OptimizeHloModule(module.get(), stream_exec,
|
||||
device_allocator));
|
||||
TF_RETURN_IF_ERROR(xla_compiler.PrepareHloModuleForIrEmitting(module.get()));
|
||||
|
||||
return std::move(module);
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
// TODO(b/137624192): Move this to custom call handling and share.
|
||||
absl::optional<bool> CanShareBufferHint(const HloInstruction* user,
|
||||
const HloInstruction* operand,
|
||||
const ShapeIndex& user_index) {
|
||||
if (user->opcode() == HloOpcode::kCustomCall) {
|
||||
// Share the bias buffer with the parent instruction.
|
||||
if (user->custom_call_target() == xla::gpu::kGemmCallTarget) {
|
||||
if (user->operand_count() == 3 && user->operand(2) == operand) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
// The operand of cholesky can be shared with the first output.
|
||||
if (user->custom_call_target() == xla::gpu::kCusolverCholeskyCallTarget) {
|
||||
return user_index.size() == 1 && user_index[0] == 0;
|
||||
}
|
||||
}
|
||||
return absl::nullopt;
|
||||
}
|
||||
|
||||
// TODO(b/137624192): Share this with nvptx backend.
|
||||
GpuVersion GetGpuVersion(se::StreamExecutor* stream_exec) {
|
||||
int cc_major, cc_minor;
|
||||
const auto& device_description = stream_exec->GetDeviceDescription();
|
||||
if (!device_description.cuda_compute_capability(&cc_major, &cc_minor)) {
|
||||
LOG(WARNING)
|
||||
<< "Couldn't get compute capability for device; assuming sm_20.";
|
||||
cc_major = 2;
|
||||
cc_minor = 0;
|
||||
}
|
||||
return std::make_pair(cc_major, cc_minor);
|
||||
}
|
||||
|
||||
// Return the constant launch bound along the "x" dimension in "dim" if all the
|
||||
// other dimensions are 1. Return nullopt otherwise or when any of the bounds
|
||||
// is not constant.
|
||||
static absl::optional<int64> getLaunchBound(const mlir::gpu::KernelDim3& dim) {
|
||||
auto get_constant = [](mlir::Operation* op,
|
||||
mlir::StringRef name) -> absl::optional<int64> {
|
||||
if (auto constant = llvm::dyn_cast_or_null<mlir::ConstantOp>(op)) {
|
||||
return constant.value().cast<mlir::IntegerAttr>().getInt();
|
||||
}
|
||||
op->emitError() << "bound " << name << " is not constant";
|
||||
return absl::nullopt;
|
||||
};
|
||||
auto y_op = dim.y.getDefiningOp();
|
||||
auto dim_y = get_constant(y_op, "y");
|
||||
if (!dim_y.has_value() || dim_y.value() != 1) {
|
||||
y_op->emitError() << "bound 'y' is not constant 1";
|
||||
return absl::nullopt;
|
||||
}
|
||||
auto z_op = dim.z.getDefiningOp();
|
||||
auto dim_z = get_constant(z_op, "z");
|
||||
if (!dim_z.has_value() || dim_z.value() != 1) {
|
||||
z_op->emitError() << "bound 'z' is not constant 1";
|
||||
return absl::nullopt;
|
||||
}
|
||||
return get_constant(dim.x.getDefiningOp(), "x");
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
// Indexes of a range of arguments in a GPU function. This is used to keep the
|
||||
// range of arguments that correspond to a lowered kernel argument of
|
||||
// (previously) memref type.
|
||||
struct LaunchFuncArgument {
|
||||
int kernel_argument_begin;
|
||||
int kernel_argument_size;
|
||||
};
|
||||
|
||||
} // end namespace
|
||||
|
||||
using OperandToValueMap =
|
||||
absl::flat_hash_map<const HloInstruction*, std::vector<LaunchFuncArgument>>;
|
||||
|
||||
static StatusOr<std::vector<const HloInstruction*>> ComputeOperandToValueMap(
|
||||
OperandToValueMap* operand_to_value_map, const HloInstruction* instr,
|
||||
LaunchFuncOp launchOp, LLVMFuncOp kernel) {
|
||||
auto operands = instr->operands();
|
||||
std::vector<const HloInstruction*> ordered_operands;
|
||||
bool has_failed = false;
|
||||
// A memref will expand into multiple kernel operands, accumulate their number
|
||||
// in order to find them later.
|
||||
int cur_operand_position = 0;
|
||||
|
||||
for (int kernel_index = 0; kernel_index < launchOp.getNumKernelOperands();
|
||||
++kernel_index) {
|
||||
auto launchop_operand =
|
||||
launchOp.getKernelOperand(kernel_index).dyn_cast<BlockArgument>();
|
||||
if (!launchop_operand) {
|
||||
launchOp.emitError("argument to kernel is not a function input");
|
||||
has_failed = true;
|
||||
continue;
|
||||
}
|
||||
auto memref_type =
|
||||
launchop_operand.getType().dyn_cast<::mlir::MemRefType>();
|
||||
if (!memref_type) {
|
||||
launchOp.emitError("only memref-typed arguments are supported");
|
||||
has_failed = true;
|
||||
break;
|
||||
}
|
||||
// host_index is the argument position to the surrounding function that
|
||||
// contains the launch. This index corresponds to HLO operand indices
|
||||
// by construction.
|
||||
auto host_index = launchop_operand.getArgNumber();
|
||||
// The trailing argument to the outer function are the results.
|
||||
auto operand =
|
||||
(host_index < operands.size()) ? operands[host_index] : instr;
|
||||
if (!operand_to_value_map->count(operand)) {
|
||||
ordered_operands.push_back(operand);
|
||||
}
|
||||
// Associate the HLO operand with the argument values of the kernel
|
||||
// function.
|
||||
int num_unpacked =
|
||||
mlir::MemRefDescriptor::getNumUnpackedValues(memref_type);
|
||||
(*operand_to_value_map)[operand].push_back(
|
||||
{cur_operand_position, num_unpacked});
|
||||
cur_operand_position += num_unpacked;
|
||||
}
|
||||
if (has_failed) {
|
||||
return InternalError("Mapping operands to kernel arguments has failed.");
|
||||
}
|
||||
return ordered_operands;
|
||||
}
|
||||
|
||||
Status InsertBufferLoadPreduleIntoKernel(
|
||||
LLVMFuncOp kernel, const OperandToValueMap& operand_to_value_map,
|
||||
const std::vector<const HloInstruction*>& ordered_operands,
|
||||
BufferAssignment* assignment,
|
||||
const std::vector<const BufferAllocation*>& buffers) {
|
||||
mlir::OpBuilder builder(kernel.getBody());
|
||||
auto llvm_dialect = kernel.getContext()->getRegisteredDialect<LLVMDialect>();
|
||||
auto offset_type = LLVMType::getInt64Ty(llvm_dialect);
|
||||
auto ptr_type = LLVMType::getInt8PtrTy(llvm_dialect);
|
||||
auto void_type = LLVMType::getVoidTy(llvm_dialect);
|
||||
auto loc = kernel.getLoc();
|
||||
|
||||
auto num_original_args = kernel.getNumArguments();
|
||||
std::vector<LLVMType> new_arg_types(buffers.size(), ptr_type);
|
||||
kernel.setAttr(kernel.getTypeAttrName(),
|
||||
mlir::TypeAttr::get(LLVMType::getFunctionTy(
|
||||
void_type, new_arg_types, /*isVarArg=*/false)));
|
||||
std::vector<Value> original_args(kernel.args_begin(), kernel.args_end());
|
||||
|
||||
std::vector<mlir::Type> as_mlir_types(new_arg_types.begin(),
|
||||
new_arg_types.end());
|
||||
auto new_args = kernel.front().addArguments(as_mlir_types);
|
||||
std::vector<Value> buffer_args(new_args.begin(), new_args.end());
|
||||
|
||||
for (auto operand : ordered_operands) {
|
||||
TF_ASSIGN_OR_RETURN(auto slice,
|
||||
assignment->GetUniqueTopLevelSlice(operand));
|
||||
auto buffer = std::find(buffers.begin(), buffers.end(), slice.allocation());
|
||||
auto index = buffer - buffers.begin();
|
||||
auto offset = builder.create<mlir::LLVM::ConstantOp>(
|
||||
loc, offset_type, builder.getI64IntegerAttr(slice.offset()));
|
||||
auto ptr = buffer_args[index];
|
||||
|
||||
// Replace uses of function arguments pertaining to memref descriptors with
|
||||
// values derived from HLO buffers. The instructions inserting these values
|
||||
// into memref descriptors were already introduced during the lowering phase
|
||||
// as per MLIR calling convention.
|
||||
for (auto arg : operand_to_value_map.at(operand)) {
|
||||
mlir::MemRefDescriptorView original(
|
||||
mlir::ValueRange(original_args)
|
||||
.slice(arg.kernel_argument_begin, arg.kernel_argument_size));
|
||||
|
||||
// Allocated and aligned pointers are the same.
|
||||
auto casted = builder.create<mlir::LLVM::BitcastOp>(
|
||||
loc, original.alignedPtr().getType().cast<LLVMType>(),
|
||||
mlir::ValueRange(ptr));
|
||||
original.alignedPtr().replaceAllUsesWith(casted);
|
||||
original.allocatedPtr().replaceAllUsesWith(casted);
|
||||
|
||||
// Use the offset of the HLO buffer instead of the one expected in the
|
||||
// function call.
|
||||
original.offset().replaceAllUsesWith(offset);
|
||||
|
||||
// Fill the shape.
|
||||
auto shape = operand->shape();
|
||||
// Unless the operand is a scalar pointer, also fill shape and strides.
|
||||
if (shape.dimensions().empty()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// TODO(b/137624192) Pass in the descriptor to allow for dynamic shapes.
|
||||
assert(shape.IsArray() && shape.is_static());
|
||||
for (auto extent : llvm::enumerate(shape.dimensions())) {
|
||||
auto shape = builder.create<mlir::LLVM::ConstantOp>(
|
||||
loc, original.size(extent.index()).getType(),
|
||||
builder.getI64IntegerAttr(extent.value()));
|
||||
original.size(extent.index()).replaceAllUsesWith(shape);
|
||||
}
|
||||
// Finally, fill the strides.
|
||||
// TODO(b/137624192): Take assigned layout into account.
|
||||
uint64_t accumulator = 0;
|
||||
for (int64_t idx = shape.rank() - 1; idx >= 0; --idx) {
|
||||
if (accumulator == 0) {
|
||||
accumulator = 1;
|
||||
} else {
|
||||
accumulator *= shape.dimensions(idx + 1);
|
||||
}
|
||||
auto stride = builder.create<mlir::LLVM::ConstantOp>(
|
||||
loc, original.stride(idx).getType(),
|
||||
builder.getI64IntegerAttr(accumulator));
|
||||
original.stride(idx).replaceAllUsesWith(stride);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Now we can remove the original arguments, as they should have no more
|
||||
// users.
|
||||
for (int i = 0; i < num_original_args; ++i) {
|
||||
kernel.front().eraseArgument(0);
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
StatusOr<std::unique_ptr<gpu::KernelThunk>> TransformKernelToXlaThunk(
|
||||
FuncOp func, const HloInstruction* const instr, ModuleOp kernel_module,
|
||||
BufferAssignment* assignment) {
|
||||
// Find the single LaunchFuncOp and compute a mapping from operands of
|
||||
// the hlo instruction to the corresponding values of the kernel
|
||||
// function in the target module;
|
||||
LaunchFuncOp launchOp;
|
||||
auto walkResult = func.walk([&launchOp](LaunchFuncOp op) {
|
||||
if (launchOp) {
|
||||
op.emitError("multiple kernels for single top-level HLO");
|
||||
return mlir::WalkResult::interrupt();
|
||||
}
|
||||
launchOp = op;
|
||||
return mlir::WalkResult::advance();
|
||||
});
|
||||
if (walkResult.wasInterrupted()) {
|
||||
return InternalError("Multiple kernels for single top-level HLO");
|
||||
}
|
||||
if (!launchOp) {
|
||||
// If there was no launchOp, then no kernel was generated, so the lowering
|
||||
// from the LHLO ops to the GPU dialect is not implemented yet.
|
||||
return Unimplemented("No kernel was generated.");
|
||||
}
|
||||
|
||||
auto kernel = kernel_module.lookupSymbol<LLVMFuncOp>(launchOp.kernel());
|
||||
|
||||
// Store the assignment of operands to block arguments. Note that an operand
|
||||
// might be used in multiple argument positions, hence the vector.
|
||||
OperandToValueMap operand_to_value_map;
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
auto ordered_operands,
|
||||
ComputeOperandToValueMap(&operand_to_value_map, instr, launchOp, kernel));
|
||||
|
||||
// Get the required buffers to support the inputs. Use a set and vector here
|
||||
// to keep the order fixed. This is mostly useful for testing.
|
||||
std::unordered_set<const BufferAllocation*> buffers_needed;
|
||||
std::vector<const BufferAllocation*> buffers;
|
||||
// TODO(b/137624192) Add support for tuples.
|
||||
for (auto operand : ordered_operands) {
|
||||
TF_ASSIGN_OR_RETURN(auto buffer,
|
||||
assignment->GetUniqueTopLevelSlice(operand));
|
||||
if (buffers_needed.insert(buffer.allocation()).second) {
|
||||
buffers.push_back(buffer.allocation());
|
||||
}
|
||||
}
|
||||
|
||||
// TODO(b/137624192) Add support for temp buffer.
|
||||
// TODO(b/137624192) Add support for constant buffers.
|
||||
|
||||
// Change the signature to match what the XLA runtime expects from the
|
||||
// kernel.
|
||||
TF_RETURN_IF_ERROR(InsertBufferLoadPreduleIntoKernel(
|
||||
kernel, operand_to_value_map, ordered_operands, assignment, buffers));
|
||||
|
||||
// Finally, create the thunk and set the launch dimensions.
|
||||
auto thunk = absl::make_unique<gpu::KernelThunk>(
|
||||
buffers, kernel.getName().str(), instr,
|
||||
/*unroll_factor=*/1);
|
||||
|
||||
// Set launch bounds.
|
||||
mlir::gpu::KernelDim3 block = launchOp.getBlockSizeOperandValues();
|
||||
mlir::gpu::KernelDim3 grid = launchOp.getGridSizeOperandValues();
|
||||
absl::optional<int64> num_threads = getLaunchBound(block);
|
||||
absl::optional<int64> num_blocks = getLaunchBound(grid);
|
||||
if (!num_threads || !num_blocks) {
|
||||
return Unimplemented("Unsupported launch bounds");
|
||||
}
|
||||
thunk->SetLaunchDimensions(gpu::LaunchDimensions(*num_blocks, *num_threads));
|
||||
return std::move(thunk);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
StatusOr<std::unique_ptr<Executable>> MlirCompiler::RunBackend(
|
||||
std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec,
|
||||
se::DeviceMemoryAllocator* device_allocator) {
|
||||
// Determine the HLO schedule, which is an ordering of HLO instructions. This
|
||||
// is used by buffer assignment to enable buffer reuse, and the same ordering
|
||||
// must also be used to determine the thunk launch schedule.
|
||||
std::unique_ptr<StreamAssignment> stream_assignment =
|
||||
xla::gpu::AssignStreams(*module);
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
std::unique_ptr<GpuHloSchedule> hlo_schedule,
|
||||
GpuHloSchedule::Build(*module, *stream_assignment, pointer_size_));
|
||||
|
||||
// Run buffer analysis on the HLO graph. This analysis figures out which
|
||||
// temporary buffers are required to run the computation.
|
||||
TF_ASSIGN_OR_RETURN(std::unique_ptr<BufferAssignment> buffer_assignment,
|
||||
BufferAssigner::Run(
|
||||
module.get(), hlo_schedule->ConsumeHloOrdering(),
|
||||
BufferSizeBytesFunction(),
|
||||
/*color_alignment=*/
|
||||
[](LogicalBuffer::Color) {
|
||||
return xla::gpu::kXlaAllocatedBufferAlignBytes;
|
||||
},
|
||||
/*allocate_buffers_for_constants=*/true,
|
||||
/*colorer=*/BufferAssigner::DefaultColorer(),
|
||||
/*must_not_live_out=*/{}, &CanShareBufferHint));
|
||||
DumpHloModuleIfEnabled(*module, *buffer_assignment, "after_optimizations");
|
||||
|
||||
EmissionContext emission_context(std::move(module));
|
||||
if (error_handler_) {
|
||||
emission_context.setErrorHandler(error_handler_);
|
||||
}
|
||||
|
||||
OwningModuleRef mlir_module =
|
||||
ModuleOp::create(UnknownLoc::get(emission_context.getContext()));
|
||||
LhloDialectEmitter lhlo_emitter(&emission_context, *buffer_assignment,
|
||||
stream_exec->platform(), *mlir_module);
|
||||
|
||||
TF_RETURN_IF_ERROR(lhlo_emitter.EmitComputation(
|
||||
*emission_context.getHloModule()->entry_computation()));
|
||||
|
||||
TF_RETURN_IF_ERROR(
|
||||
module_hook_.invoke(IRHook::LoweringStage::LHLO, *mlir_module));
|
||||
|
||||
TF_RETURN_IF_ERROR(LowerLHLOToGPU(*mlir_module));
|
||||
|
||||
TF_RETURN_IF_ERROR(
|
||||
module_hook_.invoke(IRHook::LoweringStage::GPU, *mlir_module));
|
||||
|
||||
TF_RETURN_IF_ERROR(LowerKernelBodiesToNVVM(*mlir_module));
|
||||
|
||||
TF_RETURN_IF_ERROR(
|
||||
module_hook_.invoke(IRHook::LoweringStage::LLVM, *mlir_module));
|
||||
|
||||
TF_ASSIGN_OR_RETURN(OwningModuleRef kernel_module,
|
||||
ExtractKernelModule(*mlir_module));
|
||||
|
||||
auto thunk_sequence = lhlo_emitter.ConsumeThunkSequence();
|
||||
for (auto entry : lhlo_emitter.InstructionToFunctionMap()) {
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
auto thunk,
|
||||
TransformKernelToXlaThunk(entry.second, entry.first, *kernel_module,
|
||||
buffer_assignment.get()));
|
||||
thunk_sequence->push_back(std::move(thunk));
|
||||
}
|
||||
|
||||
TF_RETURN_IF_ERROR(
|
||||
module_hook_.invoke(IRHook::LoweringStage::KERNEL, *kernel_module));
|
||||
|
||||
auto llvmModule = mlir::translateModuleToNVVMIR(*kernel_module);
|
||||
|
||||
if (!llvmModule) {
|
||||
return InternalError("Translation to LLVM failed");
|
||||
}
|
||||
|
||||
llvmModule->setModuleIdentifier(emission_context.getHloModule()->name());
|
||||
// TODO(herhut): Why is this needed and does not come from the template?
|
||||
llvmModule->setDataLayout(gpu::nvptx::kDataLayout);
|
||||
|
||||
const auto& config = emission_context.getHloModule()->config();
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
auto ptx, xla::gpu::nvptx::CompileToPtx(llvmModule.get(),
|
||||
GetGpuVersion(stream_exec),
|
||||
config, GetLibdeviceDir(config)));
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
auto cubin, se::CompileGpuAsm(stream_exec->device_ordinal(), ptx.c_str(),
|
||||
gpu::PtxOptsFromConfig(config)));
|
||||
|
||||
auto thunk_schedule = absl::make_unique<ThunkSchedule>(
|
||||
std::move(thunk_sequence), std::move(stream_assignment),
|
||||
hlo_schedule->ThunkLaunchOrder());
|
||||
|
||||
if (DumpingEnabledForHloModule(*emission_context.getHloModule())) {
|
||||
DumpToFileInDirOrStdout(*emission_context.getHloModule(), "",
|
||||
"thunk_schedule", thunk_schedule->ToString());
|
||||
}
|
||||
|
||||
// TODO(b/137624192): Add profiling support.
|
||||
return {absl::make_unique<GpuExecutable>(
|
||||
ptx, cubin, GetGpuVersion(stream_exec), std::move(thunk_schedule),
|
||||
emission_context.releaseHloModule(), std::move(buffer_assignment),
|
||||
nullptr, nullptr)};
|
||||
}
|
||||
|
||||
StatusOr<std::vector<std::unique_ptr<Executable>>> MlirCompiler::Compile(
|
||||
std::unique_ptr<HloModuleGroup> module_group,
|
||||
std::vector<std::vector<se::StreamExecutor*>> stream_execs,
|
||||
se::DeviceMemoryAllocator* device_allocator) {
|
||||
return Unimplemented("Not yet implemented in MLIR compiler");
|
||||
}
|
||||
|
||||
StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
|
||||
MlirCompiler::CompileAheadOfTime(std::unique_ptr<HloModuleGroup> module_group,
|
||||
const AotCompilationOptions& options) {
|
||||
return Unimplemented("Not yet implemented in MLIR compiler");
|
||||
}
|
||||
|
||||
void MlirCompiler::SetModuleHook(IRHook module_hook) {
|
||||
module_hook_ = module_hook;
|
||||
}
|
||||
@@ -579,14 +63,3 @@ void MlirCompiler::RemoveErrorHandler() { error_handler_ = nullptr; }
|
||||
|
||||
} // namespace mlir_gpu
|
||||
} // namespace xla
|
||||
|
||||
static bool InitModule() {
|
||||
xla::Compiler::RegisterCompilerFactory(
|
||||
stream_executor::cuda::kCudaPlatformId, []() {
|
||||
return absl::make_unique<xla::FailoverCompiler>(
|
||||
absl::make_unique<xla::mlir_gpu::MlirCompiler>(),
|
||||
absl::make_unique<xla::gpu::NVPTXCompiler>());
|
||||
});
|
||||
return true;
|
||||
}
|
||||
static bool module_initialized = InitModule();
|
||||
|
||||
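Both the code removed above and the new implementation file register the MLIR compiler wrapped in a FailoverCompiler alongside the regular NVPTXCompiler. The sketch below illustrates that delegation idea with hypothetical names and deliberately simplified semantics (a bool success flag instead of Status, and the assumption that the secondary runs whenever the primary fails); it is not the real xla::FailoverCompiler.

#include <iostream>
#include <memory>
#include <string>
#include <utility>

struct MiniCompiler {
  virtual ~MiniCompiler() = default;
  // Returns true on success; on failure the caller may try another compiler.
  virtual bool Compile(const std::string& module, std::string* out) = 0;
};

// Tries the primary compiler first and falls back to the secondary when the
// primary cannot handle the input yet.
class FailoverSketch : public MiniCompiler {
 public:
  FailoverSketch(std::unique_ptr<MiniCompiler> primary,
                 std::unique_ptr<MiniCompiler> secondary)
      : primary_(std::move(primary)), secondary_(std::move(secondary)) {}

  bool Compile(const std::string& module, std::string* out) override {
    if (primary_->Compile(module, out)) return true;
    return secondary_->Compile(module, out);
  }

 private:
  std::unique_ptr<MiniCompiler> primary_;
  std::unique_ptr<MiniCompiler> secondary_;
};

struct PartialMlirBackend : MiniCompiler {
  bool Compile(const std::string& module, std::string* out) override {
    if (module != "supported") return false;  // e.g. an Unimplemented case
    *out = "mlir:" + module;
    return true;
  }
};

struct FullNvptxBackend : MiniCompiler {
  bool Compile(const std::string& module, std::string* out) override {
    *out = "nvptx:" + module;
    return true;
  }
};

int main() {
  FailoverSketch compiler(std::make_unique<PartialMlirBackend>(),
                          std::make_unique<FullNvptxBackend>());
  std::string result;
  compiler.Compile("unsupported", &result);
  std::cout << result << "\n";  // falls back to "nvptx:unsupported"
  return 0;
}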
@@ -16,7 +16,6 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_MLIR_GPU_MLIR_COMPILER_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_MLIR_GPU_MLIR_COMPILER_H_

#include "absl/container/flat_hash_map.h"
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/Module.h"  // from @llvm-project
#include "tensorflow/compiler/xla/service/compiler.h"
@@ -27,7 +26,8 @@ namespace mlir_gpu {

// A Compiler implementation that converts XLAs IR to a matching MLIR dialect,
// performs all lowering on the MLIR IR and finally converts MLIR to LLVMIR for
// generation of a think suitable for XLAs runtime.
// generation of a thunk suitable for XLAs runtime. MlirCompilerImpl contains
// the implementation.
class MlirCompiler : public Compiler {
  using ErrorHandler =
      std::function<void(const EmissionContext::ErrorMap&, HloModule*)>;
@@ -37,30 +37,6 @@ class MlirCompiler : public Compiler {

  se::Platform::Id PlatformId() const override;

  StatusOr<std::unique_ptr<HloModule>> RunHloPasses(
      std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec,
      se::DeviceMemoryAllocator* device_allocator) override;

  StatusOr<std::unique_ptr<Executable>> RunBackend(
      std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec,
      se::DeviceMemoryAllocator* device_allocator) override;

  StatusOr<std::vector<std::unique_ptr<Executable>>> Compile(
      std::unique_ptr<HloModuleGroup> module_group,
      std::vector<std::vector<se::StreamExecutor*>> stream_execs,
      se::DeviceMemoryAllocator* device_allocator) override;

  StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
  CompileAheadOfTime(std::unique_ptr<HloModuleGroup> module_group,
                     const AotCompilationOptions& options) override;

  HloCostAnalysis::ShapeSizeFunction ShapeSizeBytesFunction() const override {
    int64 pointer_size = pointer_size_;
    return [pointer_size](const Shape& shape) {
      return ShapeUtil::ByteSizeOf(shape, pointer_size);
    };
  }

  struct IRHook {
    enum class LoweringStage { LHLO, GPU, LLVM, KERNEL };

@@ -80,7 +56,7 @@ class MlirCompiler : public Compiler {
  void SetErrorHandler(ErrorHandler error_handler);
  void RemoveErrorHandler();

 private:
 protected:
  ::mlir::MLIRContext context_;
  int64 pointer_size_;
  IRHook module_hook_;

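The IRHook and error-handler surface stays in the base class so tools like xla-gpu-opt can observe intermediate IR without linking the CUDA-dependent implementation. Below is a self-contained sketch of that callback idea, with hypothetical names and a plain string standing in for the MLIR module (the real hook receives a ModuleOp and returns a Status):

#include <functional>
#include <iostream>
#include <string>
#include <utility>

enum class LoweringStage { LHLO, GPU, LLVM, KERNEL };

using StageHook = std::function<void(LoweringStage, const std::string& ir)>;

class PipelineSketch {
 public:
  void SetModuleHook(StageHook hook) { hook_ = std::move(hook); }

  void Run() {
    // Stand-ins for the real lowering passes.
    Notify(LoweringStage::LHLO, "lhlo module");
    Notify(LoweringStage::GPU, "gpu-dialect module");
    Notify(LoweringStage::LLVM, "nvvm module");
    Notify(LoweringStage::KERNEL, "extracted kernel module");
  }

 private:
  void Notify(LoweringStage stage, const std::string& ir) {
    if (hook_) hook_(stage, ir);
  }
  StageHook hook_;
};

int main() {
  PipelineSketch pipeline;
  pipeline.SetModuleHook([](LoweringStage stage, const std::string& ir) {
    std::cout << "stage " << static_cast<int>(stage) << ": " << ir << "\n";
  });
  pipeline.Run();
  return 0;
}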
tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler_impl.cc (new file, 584 lines)
@@ -0,0 +1,584 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/container/flat_hash_map.h"
|
||||
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" // from @llvm-project
|
||||
#include "mlir/Dialect/GPU/GPUDialect.h" // from @llvm-project
|
||||
#include "mlir/Dialect/LLVMIR/LLVMDialect.h" // from @llvm-project
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
|
||||
#include "mlir/IR/Attributes.h" // from @llvm-project
|
||||
#include "mlir/IR/Function.h" // from @llvm-project
|
||||
#include "mlir/IR/Location.h" // from @llvm-project
|
||||
#include "mlir/IR/Module.h" // from @llvm-project
|
||||
#include "mlir/IR/OperationSupport.h" // from @llvm-project
|
||||
#include "mlir/IR/StandardTypes.h" // from @llvm-project
|
||||
#include "mlir/IR/Value.h" // from @llvm-project
|
||||
#include "mlir/Support/LLVM.h" // from @llvm-project
|
||||
#include "mlir/Target/NVVMIR.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
|
||||
#include "tensorflow/compiler/xla/service/dump.h"
|
||||
#include "tensorflow/compiler/xla/service/gpu/gpu_constants.h"
|
||||
#include "tensorflow/compiler/xla/service/gpu/gpu_executable.h"
|
||||
#include "tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.h"
|
||||
#include "tensorflow/compiler/xla/service/gpu/gpu_types.h"
|
||||
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
|
||||
#include "tensorflow/compiler/xla/service/gpu/kernel_thunk.h"
|
||||
#include "tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h"
|
||||
#include "tensorflow/compiler/xla/service/gpu/nvptx_compiler.h"
|
||||
#include "tensorflow/compiler/xla/service/gpu/partition_assignment.h"
|
||||
#include "tensorflow/compiler/xla/service/gpu/stream_assignment.h"
|
||||
#include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h"
|
||||
#include "tensorflow/compiler/xla/service/gpu/target_constants.h"
|
||||
#include "tensorflow/compiler/xla/service/gpu/thunk_schedule.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
|
||||
#include "tensorflow/compiler/xla/service/mlir_gpu/emission_context.h"
|
||||
#include "tensorflow/compiler/xla/service/mlir_gpu/failover_compiler.h"
|
||||
#include "tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.h"
|
||||
#include "tensorflow/compiler/xla/service/mlir_gpu/lhlo_dialect_emitter.h"
|
||||
#include "tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler.h"
|
||||
#include "tensorflow/compiler/xla/util.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/io/path.h"
|
||||
#include "tensorflow/core/platform/cuda_libdevice_path.h"
|
||||
#include "tensorflow/stream_executor/gpu/asm_compiler.h"
|
||||
|
||||
namespace xla {
|
||||
namespace mlir_gpu {
|
||||
namespace {
|
||||
|
||||
using ::mlir::BlockArgument;
|
||||
using ::mlir::dyn_cast;
|
||||
using ::mlir::FuncOp;
|
||||
using ::mlir::ModuleOp;
|
||||
using ::mlir::OwningModuleRef;
|
||||
using ::mlir::UnknownLoc;
|
||||
using ::mlir::Value;
|
||||
using ::mlir::gpu::LaunchFuncOp;
|
||||
using ::mlir::LLVM::LLVMDialect;
|
||||
using ::mlir::LLVM::LLVMFuncOp;
|
||||
using ::mlir::LLVM::LLVMType;
|
||||
using ::xla::gpu::GpuExecutable;
|
||||
using ::xla::gpu::GpuHloSchedule;
|
||||
using ::xla::gpu::GpuVersion;
|
||||
using ::xla::gpu::StreamAssignment;
|
||||
using ::xla::gpu::ThunkSchedule;
|
||||
|
||||
// A Compiler implementation that converts XLAs IR to a matching MLIR dialect,
|
||||
// performs all lowering on the MLIR IR and finally converts MLIR to LLVMIR for
|
||||
// generation of a thunk suitable for XLAs runtime.
|
||||
class MlirCompilerImpl : public MlirCompiler {
|
||||
public:
|
||||
StatusOr<std::unique_ptr<HloModule>> RunHloPasses(
|
||||
std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec,
|
||||
se::DeviceMemoryAllocator* device_allocator) override;
|
||||
|
||||
StatusOr<std::unique_ptr<Executable>> RunBackend(
|
||||
std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec,
|
||||
se::DeviceMemoryAllocator* device_allocator) override;
|
||||
|
||||
StatusOr<std::vector<std::unique_ptr<Executable>>> Compile(
|
||||
std::unique_ptr<HloModuleGroup> module_group,
|
||||
std::vector<std::vector<se::StreamExecutor*>> stream_execs,
|
||||
se::DeviceMemoryAllocator* device_allocator) override;
|
||||
|
||||
StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
|
||||
CompileAheadOfTime(std::unique_ptr<HloModuleGroup> module_group,
|
||||
const AotCompilationOptions& options) override;
|
||||
|
||||
HloCostAnalysis::ShapeSizeFunction ShapeSizeBytesFunction() const override {
|
||||
int64 pointer_size = pointer_size_;
|
||||
return [pointer_size](const Shape& shape) {
|
||||
return ShapeUtil::ByteSizeOf(shape, pointer_size);
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
// TODO(b/137624192) Share with NVPTX compiler
|
||||
static std::vector<std::string> CandidateCudaRoots(
|
||||
const HloModuleConfig& config) {
|
||||
return tensorflow::CandidateCudaRoots(
|
||||
config.debug_options().xla_gpu_cuda_data_dir());
|
||||
}
|
||||
|
||||
void PrintCantFindCudaMessage(absl::string_view msg,
|
||||
const HloModuleConfig& hlo_module_config) {
|
||||
LOG(WARNING) << msg;
|
||||
LOG(WARNING) << "Searched for CUDA in the following directories:";
|
||||
|
||||
for (const auto& dir : CandidateCudaRoots(hlo_module_config)) {
|
||||
LOG(WARNING) << " " << dir;
|
||||
}
|
||||
LOG(WARNING)
|
||||
<< "You can choose the search directory by setting xla_gpu_cuda_data_dir "
|
||||
"in HloModule's DebugOptions. For most apps, setting the environment "
|
||||
"variable XLA_FLAGS=--xla_gpu_cuda_data_dir=/path/to/cuda will work.";
|
||||
}
|
||||
|
||||
// Returns the directory containing nvvm libdevice files.
|
||||
std::string GetLibdeviceDir(const HloModuleConfig& hlo_module_config) {
|
||||
for (const string& cuda_root : CandidateCudaRoots(hlo_module_config)) {
|
||||
const std::string libdevice_dir =
|
||||
tensorflow::io::JoinPath(cuda_root, "nvvm", "libdevice");
|
||||
VLOG(2) << "Looking for libdevice at " << libdevice_dir;
|
||||
if (tensorflow::Env::Default()->IsDirectory(libdevice_dir).ok()) {
|
||||
VLOG(2) << "Found libdevice dir " << libdevice_dir;
|
||||
return libdevice_dir;
|
||||
}
|
||||
}
|
||||
PrintCantFindCudaMessage(
|
||||
"Can't find libdevice directory ${CUDA_DIR}/nvvm/libdevice. This may "
|
||||
"result in compilation or runtime failures, if the program we try to run "
|
||||
"uses routines from libdevice.",
|
||||
hlo_module_config);
|
||||
|
||||
// GetCudaRootCandidates always includes ".", but if everything fails, we
|
||||
// return it anyway. Better than returning the empty string.
|
||||
return ".";
|
||||
}
|
||||
|
||||
StatusOr<std::unique_ptr<HloModule>> MlirCompilerImpl::RunHloPasses(
|
||||
std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec,
|
||||
se::DeviceMemoryAllocator* device_allocator) {
|
||||
// Until we find a reason to do something different, run the same passes
|
||||
// that the normal GPU backend runs.
|
||||
gpu::NVPTXCompiler xla_compiler;
|
||||
TF_RETURN_IF_ERROR(xla_compiler.OptimizeHloModule(module.get(), stream_exec,
|
||||
device_allocator));
|
||||
TF_RETURN_IF_ERROR(xla_compiler.PrepareHloModuleForIrEmitting(module.get()));
|
||||
|
||||
return std::move(module);
|
||||
}
|
||||
|
||||
// TODO(b/137624192): Move this to custom call handling and share.
|
||||
absl::optional<bool> CanShareBufferHint(const HloInstruction* user,
|
||||
const HloInstruction* operand,
|
||||
const ShapeIndex& user_index) {
|
||||
if (user->opcode() == HloOpcode::kCustomCall) {
|
||||
// Share the bias buffer with the parent instruction.
|
||||
if (user->custom_call_target() == xla::gpu::kGemmCallTarget) {
|
||||
if (user->operand_count() == 3 && user->operand(2) == operand) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
// The operand of cholesky can be shared with the first output.
|
||||
if (user->custom_call_target() == xla::gpu::kCusolverCholeskyCallTarget) {
|
||||
return user_index.size() == 1 && user_index[0] == 0;
|
||||
}
|
||||
}
|
||||
return absl::nullopt;
|
||||
}
|
||||
|
||||
// TODO(b/137624192): Share this with nvptx backend.
|
||||
GpuVersion GetGpuVersion(se::StreamExecutor* stream_exec) {
|
||||
int cc_major, cc_minor;
|
||||
const auto& device_description = stream_exec->GetDeviceDescription();
|
||||
if (!device_description.cuda_compute_capability(&cc_major, &cc_minor)) {
|
||||
LOG(WARNING)
|
||||
<< "Couldn't get compute capability for device; assuming sm_20.";
|
||||
cc_major = 2;
|
||||
cc_minor = 0;
|
||||
}
|
||||
return std::make_pair(cc_major, cc_minor);
|
||||
}
|
||||
|
||||
// Return the constant launch bound along the "x" dimension in "dim" if all the
|
||||
// other dimensions are 1. Return nullopt otherwise or when any of the bounds
|
||||
// is not constant.
|
||||
static absl::optional<int64> getLaunchBound(const mlir::gpu::KernelDim3& dim) {
|
||||
auto get_constant = [](mlir::Operation* op,
|
||||
mlir::StringRef name) -> absl::optional<int64> {
|
||||
if (auto constant = llvm::dyn_cast_or_null<mlir::ConstantOp>(op)) {
|
||||
return constant.value().cast<mlir::IntegerAttr>().getInt();
|
||||
}
|
||||
op->emitError() << "bound " << name << " is not constant";
|
||||
return absl::nullopt;
|
||||
};
|
||||
auto y_op = dim.y.getDefiningOp();
|
||||
auto dim_y = get_constant(y_op, "y");
|
||||
if (!dim_y.has_value() || dim_y.value() != 1) {
|
||||
y_op->emitError() << "bound 'y' is not constant 1";
|
||||
return absl::nullopt;
|
||||
}
|
||||
auto z_op = dim.z.getDefiningOp();
|
||||
auto dim_z = get_constant(z_op, "z");
|
||||
if (!dim_z.has_value() || dim_z.value() != 1) {
|
||||
z_op->emitError() << "bound 'z' is not constant 1";
|
||||
return absl::nullopt;
|
||||
}
|
||||
return get_constant(dim.x.getDefiningOp(), "x");
|
||||
}
|
||||
|
||||
// Indexes of a range of arguments in a GPU function. This is used to keep the
|
||||
// range of arguments that correspond to a lowered kernel argument of
|
||||
// (previously) memref type.
|
||||
struct LaunchFuncArgument {
|
||||
int kernel_argument_begin;
|
||||
int kernel_argument_size;
|
||||
};
|
||||
|
||||
using OperandToValueMap =
|
||||
absl::flat_hash_map<const HloInstruction*, std::vector<LaunchFuncArgument>>;
|
||||
|
||||
static StatusOr<std::vector<const HloInstruction*>> ComputeOperandToValueMap(
|
||||
OperandToValueMap* operand_to_value_map, const HloInstruction* instr,
|
||||
LaunchFuncOp launchOp, LLVMFuncOp kernel) {
|
||||
auto operands = instr->operands();
|
||||
std::vector<const HloInstruction*> ordered_operands;
|
||||
bool has_failed = false;
|
||||
// A memref will expand into multiple kernel operands, accumulate their number
|
||||
// in order to find them later.
|
||||
int cur_operand_position = 0;
|
||||
|
||||
for (int kernel_index = 0; kernel_index < launchOp.getNumKernelOperands();
|
||||
++kernel_index) {
|
||||
auto launchop_operand =
|
||||
launchOp.getKernelOperand(kernel_index).dyn_cast<BlockArgument>();
|
||||
if (!launchop_operand) {
|
||||
launchOp.emitError("argument to kernel is not a function input");
|
||||
has_failed = true;
|
||||
continue;
|
||||
}
|
||||
auto memref_type =
|
||||
launchop_operand.getType().dyn_cast<::mlir::MemRefType>();
|
||||
if (!memref_type) {
|
||||
launchOp.emitError("only memref-typed arguments are supported");
|
||||
has_failed = true;
|
||||
break;
|
||||
}
|
||||
// host_index is the argument position to the surrounding function that
|
||||
// contains the launch. This index corresponds to HLO operand indices
|
||||
// by construction.
|
||||
auto host_index = launchop_operand.getArgNumber();
|
||||
// The trailing argument to the outer function are the results.
|
||||
auto operand =
|
||||
(host_index < operands.size()) ? operands[host_index] : instr;
|
||||
if (!operand_to_value_map->count(operand)) {
|
||||
ordered_operands.push_back(operand);
|
||||
}
|
||||
// Associate the HLO operand with the argument values of the kernel
|
||||
// function.
|
||||
int num_unpacked =
|
||||
mlir::MemRefDescriptor::getNumUnpackedValues(memref_type);
|
||||
(*operand_to_value_map)[operand].push_back(
|
||||
{cur_operand_position, num_unpacked});
|
||||
cur_operand_position += num_unpacked;
|
||||
}
|
||||
if (has_failed) {
|
||||
return InternalError("Mapping operands to kernel arguments has failed.");
|
||||
}
|
||||
return ordered_operands;
|
||||
}
|
||||
|
||||
Status InsertBufferLoadPreduleIntoKernel(
|
||||
LLVMFuncOp kernel, const OperandToValueMap& operand_to_value_map,
|
||||
const std::vector<const HloInstruction*>& ordered_operands,
|
||||
BufferAssignment* assignment,
|
||||
const std::vector<const BufferAllocation*>& buffers) {
|
||||
mlir::OpBuilder builder(kernel.getBody());
|
||||
auto llvm_dialect = kernel.getContext()->getRegisteredDialect<LLVMDialect>();
|
||||
auto offset_type = LLVMType::getInt64Ty(llvm_dialect);
|
||||
auto ptr_type = LLVMType::getInt8PtrTy(llvm_dialect);
|
||||
auto void_type = LLVMType::getVoidTy(llvm_dialect);
|
||||
auto loc = kernel.getLoc();
|
||||
|
||||
auto num_original_args = kernel.getNumArguments();
|
||||
std::vector<LLVMType> new_arg_types(buffers.size(), ptr_type);
|
||||
kernel.setAttr(kernel.getTypeAttrName(),
|
||||
mlir::TypeAttr::get(LLVMType::getFunctionTy(
|
||||
void_type, new_arg_types, /*isVarArg=*/false)));
|
||||
std::vector<Value> original_args(kernel.args_begin(), kernel.args_end());
|
||||
|
||||
std::vector<mlir::Type> as_mlir_types(new_arg_types.begin(),
|
||||
new_arg_types.end());
|
||||
auto new_args = kernel.front().addArguments(as_mlir_types);
|
||||
std::vector<Value> buffer_args(new_args.begin(), new_args.end());
|
||||
|
||||
for (auto operand : ordered_operands) {
|
||||
TF_ASSIGN_OR_RETURN(auto slice,
|
||||
assignment->GetUniqueTopLevelSlice(operand));
|
||||
auto buffer = std::find(buffers.begin(), buffers.end(), slice.allocation());
|
||||
auto index = buffer - buffers.begin();
|
||||
auto offset = builder.create<mlir::LLVM::ConstantOp>(
|
||||
loc, offset_type, builder.getI64IntegerAttr(slice.offset()));
|
||||
auto ptr = buffer_args[index];
|
||||
|
||||
// Replace uses of function arguments pertaining to memref descriptors with
|
||||
// values derived from HLO buffers. The instructions inserting these values
|
||||
// into memref descriptors were already introduced during the lowering phase
|
||||
// as per MLIR calling convention.
|
||||
for (auto arg : operand_to_value_map.at(operand)) {
|
||||
mlir::MemRefDescriptorView original(
|
||||
mlir::ValueRange(original_args)
|
||||
.slice(arg.kernel_argument_begin, arg.kernel_argument_size));
|
||||
|
||||
// Allocated and aligned pointers are the same.
|
||||
auto casted = builder.create<mlir::LLVM::BitcastOp>(
|
||||
loc, original.alignedPtr().getType().cast<LLVMType>(),
|
||||
mlir::ValueRange(ptr));
|
||||
original.alignedPtr().replaceAllUsesWith(casted);
|
||||
original.allocatedPtr().replaceAllUsesWith(casted);
|
||||
|
||||
// Use the offset of the HLO buffer instead of the one expected in the
|
||||
// function call.
|
||||
original.offset().replaceAllUsesWith(offset);
|
||||
|
||||
// Fill the shape.
|
||||
auto shape = operand->shape();
|
||||
// Unless the operand is a scalar pointer, also fill shape and strides.
|
||||
if (shape.dimensions().empty()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// TODO(b/137624192) Pass in the descriptor to allow for dynamic shapes.
|
||||
assert(shape.IsArray() && shape.is_static());
|
||||
for (auto extent : llvm::enumerate(shape.dimensions())) {
|
||||
auto shape = builder.create<mlir::LLVM::ConstantOp>(
|
||||
loc, original.size(extent.index()).getType(),
|
||||
builder.getI64IntegerAttr(extent.value()));
|
||||
original.size(extent.index()).replaceAllUsesWith(shape);
|
||||
}
|
||||
// Finally, fill the strides.
|
||||
// TODO(b/137624192): Take assigned layout into account.
|
||||
uint64_t accumulator = 0;
|
||||
for (int64_t idx = shape.rank() - 1; idx >= 0; --idx) {
|
||||
if (accumulator == 0) {
|
||||
accumulator = 1;
|
||||
} else {
|
||||
accumulator *= shape.dimensions(idx + 1);
|
||||
}
|
||||
auto stride = builder.create<mlir::LLVM::ConstantOp>(
|
||||
loc, original.stride(idx).getType(),
|
||||
builder.getI64IntegerAttr(accumulator));
|
||||
original.stride(idx).replaceAllUsesWith(stride);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Now we can remove the original arguments, as they should have no more
|
||||
// users.
|
||||
for (int i = 0; i < num_original_args; ++i) {
|
||||
kernel.front().eraseArgument(0);
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
StatusOr<std::unique_ptr<gpu::KernelThunk>> TransformKernelToXlaThunk(
|
||||
FuncOp func, const HloInstruction* const instr, ModuleOp kernel_module,
|
||||
BufferAssignment* assignment) {
|
||||
// Find the single LaunchFuncOp and compute a mapping from operands of
|
||||
// the hlo instruction to the corresponding values of the kernel
|
||||
// function in the target module;
|
||||
LaunchFuncOp launchOp;
|
||||
auto walkResult = func.walk([&launchOp](LaunchFuncOp op) {
|
||||
if (launchOp) {
|
||||
op.emitError("multiple kernels for single top-level HLO");
|
||||
return mlir::WalkResult::interrupt();
|
||||
}
|
||||
launchOp = op;
|
||||
return mlir::WalkResult::advance();
|
||||
});
|
||||
if (walkResult.wasInterrupted()) {
|
||||
return InternalError("Multiple kernels for single top-level HLO");
|
||||
}
|
||||
if (!launchOp) {
|
||||
// If there was no launchOp, then no kernel was generated, so the lowering
|
||||
// from the LHLO ops to the GPU dialect is not implemented yet.
|
||||
return Unimplemented("No kernel was generated.");
|
||||
}
|
||||
|
||||
auto kernel = kernel_module.lookupSymbol<LLVMFuncOp>(launchOp.kernel());
|
||||
|
||||
// Store the assignment of operands to block arguments. Note that an operand
|
||||
// might be used in multiple argument positions, hence the vector.
|
||||
OperandToValueMap operand_to_value_map;
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
auto ordered_operands,
|
||||
ComputeOperandToValueMap(&operand_to_value_map, instr, launchOp, kernel));
|
||||
|
||||
// Get the required buffers to support the inputs. Use a set and vector here
|
||||
// to keep the order fixed. This is mostly useful for testing.
|
||||
std::unordered_set<const BufferAllocation*> buffers_needed;
|
||||
std::vector<const BufferAllocation*> buffers;
|
||||
// TODO(b/137624192) Add support for tuples.
|
||||
for (auto operand : ordered_operands) {
|
||||
TF_ASSIGN_OR_RETURN(auto buffer,
|
||||
assignment->GetUniqueTopLevelSlice(operand));
|
||||
if (buffers_needed.insert(buffer.allocation()).second) {
|
||||
buffers.push_back(buffer.allocation());
|
||||
}
|
||||
}
|
||||
|
||||
// TODO(b/137624192) Add support for temp buffer.
|
||||
// TODO(b/137624192) Add support for constant buffers.
|
||||
|
||||
// Change the signature to match what the XLA runtime expects from the
|
||||
// kernel.
|
||||
TF_RETURN_IF_ERROR(InsertBufferLoadPreduleIntoKernel(
|
||||
kernel, operand_to_value_map, ordered_operands, assignment, buffers));
|
||||
|
||||
// Finally, create the thunk and set the launch dimensions.
|
||||
auto thunk = absl::make_unique<gpu::KernelThunk>(
|
||||
buffers, kernel.getName().str(), instr,
|
||||
/*unroll_factor=*/1);
|
||||
|
||||
// Set launch bounds.
|
||||
mlir::gpu::KernelDim3 block = launchOp.getBlockSizeOperandValues();
|
||||
mlir::gpu::KernelDim3 grid = launchOp.getGridSizeOperandValues();
|
||||
absl::optional<int64> num_threads = getLaunchBound(block);
|
||||
absl::optional<int64> num_blocks = getLaunchBound(grid);
|
||||
if (!num_threads || !num_blocks) {
|
||||
return Unimplemented("Unsupported launch bounds");
|
||||
}
|
||||
thunk->SetLaunchDimensions(gpu::LaunchDimensions(*num_blocks, *num_threads));
|
||||
return std::move(thunk);
|
||||
}
|
||||
|
||||
StatusOr<std::unique_ptr<Executable>> MlirCompilerImpl::RunBackend(
|
||||
std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec,
|
||||
se::DeviceMemoryAllocator* device_allocator) {
|
||||
// Determine the HLO schedule, which is an ordering of HLO instructions. This
|
||||
// is used by buffer assignment to enable buffer reuse, and the same ordering
|
||||
// must also be used to determine the thunk launch schedule.
|
||||
std::unique_ptr<StreamAssignment> stream_assignment =
|
||||
xla::gpu::AssignStreams(*module);
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
std::unique_ptr<GpuHloSchedule> hlo_schedule,
|
||||
GpuHloSchedule::Build(*module, *stream_assignment, pointer_size_));
|
||||
|
||||
// Run buffer analysis on the HLO graph. This analysis figures out which
|
||||
// temporary buffers are required to run the computation.
|
||||
TF_ASSIGN_OR_RETURN(std::unique_ptr<BufferAssignment> buffer_assignment,
|
||||
BufferAssigner::Run(
|
||||
module.get(), hlo_schedule->ConsumeHloOrdering(),
|
||||
BufferSizeBytesFunction(),
|
||||
/*color_alignment=*/
|
||||
[](LogicalBuffer::Color) {
|
||||
return xla::gpu::kXlaAllocatedBufferAlignBytes;
|
||||
},
|
||||
/*allocate_buffers_for_constants=*/true,
|
||||
/*colorer=*/BufferAssigner::DefaultColorer(),
|
||||
/*must_not_live_out=*/{}, &CanShareBufferHint));
|
||||
DumpHloModuleIfEnabled(*module, *buffer_assignment, "after_optimizations");
|
||||
|
||||
EmissionContext emission_context(std::move(module));
|
||||
if (error_handler_) {
|
||||
emission_context.setErrorHandler(error_handler_);
|
||||
}
|
||||
|
||||
OwningModuleRef mlir_module =
|
||||
ModuleOp::create(UnknownLoc::get(emission_context.getContext()));
|
||||
LhloDialectEmitter lhlo_emitter(&emission_context, *buffer_assignment,
|
||||
stream_exec->platform(), *mlir_module);
|
||||
|
||||
TF_RETURN_IF_ERROR(lhlo_emitter.EmitComputation(
|
||||
*emission_context.getHloModule()->entry_computation()));
|
||||
|
||||
TF_RETURN_IF_ERROR(
|
||||
module_hook_.invoke(IRHook::LoweringStage::LHLO, *mlir_module));
|
||||
|
||||
TF_RETURN_IF_ERROR(LowerLHLOToGPU(*mlir_module));
|
||||
|
||||
TF_RETURN_IF_ERROR(
|
||||
module_hook_.invoke(IRHook::LoweringStage::GPU, *mlir_module));
|
||||
|
||||
TF_RETURN_IF_ERROR(LowerKernelBodiesToNVVM(*mlir_module));
|
||||
|
||||
TF_RETURN_IF_ERROR(
|
||||
module_hook_.invoke(IRHook::LoweringStage::LLVM, *mlir_module));
|
||||
|
||||
TF_ASSIGN_OR_RETURN(OwningModuleRef kernel_module,
|
||||
ExtractKernelModule(*mlir_module));
|
||||
|
||||
auto thunk_sequence = lhlo_emitter.ConsumeThunkSequence();
|
||||
for (auto entry : lhlo_emitter.InstructionToFunctionMap()) {
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
auto thunk,
|
||||
TransformKernelToXlaThunk(entry.second, entry.first, *kernel_module,
|
||||
buffer_assignment.get()));
|
||||
thunk_sequence->push_back(std::move(thunk));
|
||||
}
|
||||
|
||||
TF_RETURN_IF_ERROR(
|
||||
module_hook_.invoke(IRHook::LoweringStage::KERNEL, *kernel_module));
|
||||
|
||||
auto llvmModule = mlir::translateModuleToNVVMIR(*kernel_module);
|
||||
|
||||
if (!llvmModule) {
|
||||
return InternalError("Translation to LLVM failed");
|
||||
}
|
||||
|
||||
llvmModule->setModuleIdentifier(emission_context.getHloModule()->name());
|
||||
// TODO(herhut): Why is this needed and does not come from the template?
|
||||
llvmModule->setDataLayout(gpu::nvptx::kDataLayout);
|
||||
|
||||
const auto& config = emission_context.getHloModule()->config();
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
auto ptx, xla::gpu::nvptx::CompileToPtx(llvmModule.get(),
|
||||
GetGpuVersion(stream_exec),
|
||||
config, GetLibdeviceDir(config)));
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
auto cubin, se::CompileGpuAsm(stream_exec->device_ordinal(), ptx.c_str(),
|
||||
gpu::PtxOptsFromConfig(config)));
|
||||
|
||||
auto thunk_schedule = absl::make_unique<ThunkSchedule>(
|
||||
std::move(thunk_sequence), std::move(stream_assignment),
|
||||
hlo_schedule->ThunkLaunchOrder());
|
||||
|
||||
if (DumpingEnabledForHloModule(*emission_context.getHloModule())) {
|
||||
DumpToFileInDirOrStdout(*emission_context.getHloModule(), "",
|
||||
"thunk_schedule", thunk_schedule->ToString());
|
||||
}
|
||||
|
||||
// TODO(b/137624192): Add profiling support.
|
||||
return {absl::make_unique<GpuExecutable>(
|
||||
ptx, cubin, GetGpuVersion(stream_exec), std::move(thunk_schedule),
|
||||
emission_context.releaseHloModule(), std::move(buffer_assignment),
|
||||
nullptr, nullptr)};
|
||||
}
|
||||
|
||||
StatusOr<std::vector<std::unique_ptr<Executable>>> MlirCompilerImpl::Compile(
|
||||
std::unique_ptr<HloModuleGroup> module_group,
|
||||
std::vector<std::vector<se::StreamExecutor*>> stream_execs,
|
||||
se::DeviceMemoryAllocator* device_allocator) {
|
||||
return Unimplemented("Not yet implemented in MLIR compiler");
|
||||
}
|
||||
|
||||
StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
|
||||
MlirCompilerImpl::CompileAheadOfTime(
|
||||
std::unique_ptr<HloModuleGroup> /*module_group*/,
|
||||
const AotCompilationOptions& /*options*/) {
|
||||
return Unimplemented("Not yet implemented in MLIR compiler");
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace mlir_gpu
|
||||
} // namespace xla
|
||||
|
||||
static bool InitModule() {
|
||||
xla::Compiler::RegisterCompilerFactory(
|
||||
stream_executor::cuda::kCudaPlatformId, []() {
|
||||
return absl::make_unique<xla::FailoverCompiler>(
|
||||
absl::make_unique<xla::mlir_gpu::MlirCompilerImpl>(),
|
||||
absl::make_unique<xla::gpu::NVPTXCompiler>());
|
||||
});
|
||||
return true;
|
||||
}
|
||||
static bool module_initialized = InitModule();