From 01b38cd7c651e0e83d7503671669ea9eb13afe81 Mon Sep 17 00:00:00 2001 From: Benjamin Kramer Date: Wed, 20 May 2020 06:40:25 -0700 Subject: [PATCH] [XLA:CPU] Plumb through a minimal emitter for matmuls using the mlir linalg dialect This is just the most basic lowering and will generate linalg.matmul for small matmuls and then convert to loops. The result is fairly slow, but we can iterate on that. To make XLA use it set XLA_FLAGS=--xla_backend_extra_options=xla_use_linalg_for_dot PiperOrigin-RevId: 312471829 Change-Id: I213d1f6114671bc595ac1647d3689736ee8f56f4 --- tensorflow/compiler/xla/service/cpu/BUILD | 30 ++++ .../compiler/xla/service/cpu/cpu_compiler.cc | 23 ++- .../compiler/xla/service/cpu/cpu_options.cc | 7 + .../compiler/xla/service/cpu/cpu_options.h | 1 + .../xla/service/cpu/dot_op_emitter.cc | 89 +++++++++--- .../compiler/xla/service/cpu/dot_op_emitter.h | 3 +- .../compiler/xla/service/cpu/ir_emitter.cc | 15 +- .../compiler/xla/service/cpu/ir_emitter.h | 11 +- .../compiler/xla/service/cpu/mlir_emitter.cc | 132 ++++++++++++++++++ .../compiler/xla/service/cpu/mlir_emitter.h | 43 ++++++ 10 files changed, 315 insertions(+), 39 deletions(-) create mode 100644 tensorflow/compiler/xla/service/cpu/mlir_emitter.cc create mode 100644 tensorflow/compiler/xla/service/cpu/mlir_emitter.h diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD index 2f432cd9356..3460e65b0a2 100644 --- a/tensorflow/compiler/xla/service/cpu/BUILD +++ b/tensorflow/compiler/xla/service/cpu/BUILD @@ -118,6 +118,9 @@ cc_library( ":target_machine_features", "@com_google_absl//absl/base", "@com_google_absl//absl/types:span", + "@llvm-project//mlir:AllPassesAndDialectsNoRegistration", + "@llvm-project//mlir:ExecutionEngineUtils", + "@llvm-project//mlir:LLVMDialect", "//tensorflow/compiler/xla/service:copy_insertion", "//tensorflow/compiler/xla/service:hlo_casting_utils", "//tensorflow/compiler/xla/service:dump", @@ -366,6 +369,7 @@ cc_library( "@llvm-project//llvm:core", "@llvm-project//llvm:support", "@llvm-project//llvm:target", + "@llvm-project//mlir:IR", ], ) @@ -456,6 +460,7 @@ cc_library( ":cpu_options", ":cpu_runtime", ":ir_emission_utils", + ":mlir_emitter", ":target_machine_features", ":tiled_dot_emitter", ":vector_support_library", @@ -474,6 +479,10 @@ cc_library( "//tensorflow/core:lib", "@com_google_absl//absl/strings", "@llvm-project//llvm:core", + "@llvm-project//mlir:EDSC", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:LinalgOps", + "@llvm-project//mlir:StandardOps", ], ) @@ -1070,3 +1079,24 @@ tf_cc_test( "@llvm-project//llvm:target", ], ) + +cc_library( + name = "mlir_emitter", + srcs = ["mlir_emitter.cc"], + hdrs = ["mlir_emitter.h"], + deps = [ + "//tensorflow/compiler/mlir/xla:hlo_utils", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status", + "@llvm-project//llvm:core", + "@llvm-project//llvm:ipo", + "@llvm-project//llvm:linker", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:LLVMTransforms", + "@llvm-project//mlir:LinalgToLLVM", + "@llvm-project//mlir:LinalgTransforms", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:TargetLLVMIR", + "@llvm-project//mlir:VectorToLLVM", + ], +) diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc index fe769bbdd2a..b2416ac2799 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc @@ -42,6 +42,8 @@ limitations under the License. #include "llvm/Support/TargetSelect.h" #include "llvm/Target/TargetMachine.h" #include "llvm/Target/TargetOptions.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" // from @llvm-project +#include "mlir/InitAllDialects.h" // from @llvm-project #include "tensorflow/compiler/xla/cpu_function_runtime.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/map_util.h" @@ -158,6 +160,8 @@ CpuCompiler::CpuCompiler() { // Initialize LLVM's MC layer for the native target. llvm::InitializeNativeTarget(); llvm::InitializeNativeTargetAsmPrinter(); + + mlir::registerAllDialects(); } namespace { @@ -606,9 +610,11 @@ StatusOr> CpuCompiler::RunBackend( user_post_optimization_hook_); // Compile must be thread-safe so create a new LLVM context for the module. - auto llvm_context = absl::make_unique(); - auto llvm_module = - absl::make_unique("__compute_module", *llvm_context); + mlir::MLIRContext mlir_context; + auto llvm_module = absl::make_unique( + "__compute_module", + mlir_context.getRegisteredDialect() + ->getLLVMContext()); auto jit = absl::make_unique( CompilerTargetOptions(module->config()), @@ -662,7 +668,7 @@ StatusOr> CpuCompiler::RunBackend( // before a caller computation. LLVMTargetMachineFeatures target_machine_features(jit->target_machine()); - IrEmitter ir_emitter(*module, *assignment, llvm_module.get(), + IrEmitter ir_emitter(&mlir_context, *module, *assignment, llvm_module.get(), std::move(instruction_to_profile_idx), std::move(computation_to_profile_idx), &target_machine_features, @@ -816,8 +822,11 @@ CpuCompiler::CompileAheadOfTime(std::unique_ptr module_group, opt_level)); // Compile must be thread-safe so create a new LLVM context for the module. - llvm::LLVMContext llvm_context; - llvm::Module llvm_module("__compute_module", llvm_context); + mlir::MLIRContext mlir_context; + llvm::Module llvm_module( + "__compute_module", + mlir_context.getRegisteredDialect() + ->getLLVMContext()); llvm_module.setDataLayout(target_machine->createDataLayout()); llvm_module.setTargetTriple(triple.getTriple()); if (pic_level != llvm::PICLevel::NotPIC) { @@ -866,7 +875,7 @@ CpuCompiler::CompileAheadOfTime(std::unique_ptr module_group, } LLVMTargetMachineFeatures target_machine_features(target_machine.get()); - IrEmitter ir_emitter(*module, *assignment, &llvm_module, + IrEmitter ir_emitter(&mlir_context, *module, *assignment, &llvm_module, std::move(instruction_to_profile_idx), std::move(computation_to_profile_idx), &target_machine_features, diff --git a/tensorflow/compiler/xla/service/cpu/cpu_options.cc b/tensorflow/compiler/xla/service/cpu/cpu_options.cc index ff654c83d61..c0222010fd9 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_options.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_options.cc @@ -25,6 +25,7 @@ const char* const kXlaOptimizeForSizeCpuOption = "xla_cpu_optimize_for_size"; const char* const kLlvmIrDotTilingFactor = "xla_llvm_dot_tiling_factor"; const char* const kXlaForceEnableExperimentalLlvmIrGemm = "xla_force_enable_experimental_llvm_ir_gemm"; +const char* const kXlaUseLinalgForDot = "xla_use_linalg_for_dot"; const char* const kLlvmIrGemmTileSize = "xla_llvm_ir_gemm_tile_size"; } // namespace @@ -63,6 +64,12 @@ bool ForceEnableExperimentalLlvmIrGemm(const HloModuleConfig& config) { return extra_options_map.count(kXlaForceEnableExperimentalLlvmIrGemm) > 0; } +bool UseLinalgForDot(const HloModuleConfig& config) { + const auto& extra_options_map = + config.debug_options().xla_backend_extra_options(); + return extra_options_map.count(kXlaUseLinalgForDot) > 0; +} + static absl::string_view RemoveSuffix(absl::string_view str, absl::string_view suffix) { CHECK_GE(str.size(), suffix.size()); diff --git a/tensorflow/compiler/xla/service/cpu/cpu_options.h b/tensorflow/compiler/xla/service/cpu/cpu_options.h index 99e6702d14a..5d25aef6912 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_options.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_options.h @@ -27,6 +27,7 @@ namespace options { bool OptimizeForSizeRequested(const HloModuleConfig& config); bool VectorizedReduceDisabled(const HloModuleConfig& config); bool ForceEnableExperimentalLlvmIrGemm(const HloModuleConfig& config); +bool UseLinalgForDot(const HloModuleConfig& config); absl::optional LlvmIrGemvTilingFactor(const HloModuleConfig& config); absl::optional> LlvmIrGemmTileSize( const HloModuleConfig& config); diff --git a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc index 7dba826b65c..e1ad14600d7 100644 --- a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc @@ -23,8 +23,17 @@ limitations under the License. #include "llvm/IR/Instructions.h" #include "llvm/IR/Module.h" #include "llvm/IR/Value.h" +#include "mlir/Dialect/Linalg/EDSC/Intrinsics.h" // from @llvm-project +#include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h" // from @llvm-project +#include "mlir/EDSC/Builders.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Function.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "tensorflow/compiler/xla/service/cpu/cpu_options.h" #include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h" #include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h" +#include "tensorflow/compiler/xla/service/cpu/mlir_emitter.h" #include "tensorflow/compiler/xla/service/cpu/target_machine_features.h" #include "tensorflow/compiler/xla/service/cpu/tiled_dot_emitter.h" #include "tensorflow/compiler/xla/service/cpu/vector_support_library.h" @@ -89,6 +98,9 @@ enum class DotImplementationStrategy { // and the output have to be row major. kTiledLlvmIrGemm, + // The dot operation is lowered into linalg.matmul op and lowered to LLVM IR. + kLinalgMatmul, + // The dot operation is lowered into a call into an Eigen routine. No fusions // are supported today. The two inputs and the output have to be row major. // However, we do allow transposing either the LHS or the RHS as part of the @@ -112,7 +124,7 @@ class DotOpEmitter { const llvm_ir::IrArray& rhs_array, const llvm_ir::IrArray* addend_array, llvm::Value* executable_run_options_value, - llvm::IRBuilder<>* b, + llvm::IRBuilder<>* b, mlir::MLIRContext* mlir_context, const HloModuleConfig& hlo_module_config, const TargetMachineFeatures& target_machine_features); @@ -163,6 +175,9 @@ class DotOpEmitter { // Lowers the dot operation as a tiled Matrix*Matrix loop. void EmitTiledLlvmIrGemm(); + // Lowers the dot operation through MLIR's linalg.matmul. + Status EmitLinalgMatmul(); + // Lowers the dot operation as a naive nested loop that computes the result // one element at a time. void EmitNaiveLlvmIrGemm(); @@ -194,20 +209,19 @@ class DotOpEmitter { const llvm_ir::IrArray* addend_array_; llvm::Value* executable_run_options_value_; llvm::IRBuilder<>* b_; + mlir::MLIRContext* mlir_context_; const HloModuleConfig& hlo_module_config_; const TargetMachineFeatures& target_machine_features_; }; } // namespace -DotOpEmitter::DotOpEmitter(DotInfo dot_info, string dot_hlo_name, - const llvm_ir::IrArray& target_array, - const llvm_ir::IrArray& lhs_array, - const llvm_ir::IrArray& rhs_array, - const llvm_ir::IrArray* addend_array, - llvm::Value* executable_run_options_value, - llvm::IRBuilder<>* b, - const HloModuleConfig& hlo_module_config, - const TargetMachineFeatures& target_machine_features) +DotOpEmitter::DotOpEmitter( + DotInfo dot_info, string dot_hlo_name, const llvm_ir::IrArray& target_array, + const llvm_ir::IrArray& lhs_array, const llvm_ir::IrArray& rhs_array, + const llvm_ir::IrArray* addend_array, + llvm::Value* executable_run_options_value, llvm::IRBuilder<>* b, + mlir::MLIRContext* mlir_context, const HloModuleConfig& hlo_module_config, + const TargetMachineFeatures& target_machine_features) : dot_info_(std::move(dot_info)), dot_hlo_name_(std::move(dot_hlo_name)), target_array_(target_array), @@ -216,9 +230,36 @@ DotOpEmitter::DotOpEmitter(DotInfo dot_info, string dot_hlo_name, addend_array_(addend_array), executable_run_options_value_(executable_run_options_value), b_(b), + mlir_context_(mlir_context), hlo_module_config_(hlo_module_config), target_machine_features_(target_machine_features) {} +Status DotOpEmitter::EmitLinalgMatmul() { + Shape operand_shapes[] = {dot_info_.lhs_shape, dot_info_.rhs_shape}; + llvm::Value* operand_ptrs[] = {lhs_array_.GetBasePointer(), + rhs_array_.GetBasePointer()}; + llvm::Value* target_ptr = target_array_.GetBasePointer(); + + // Zero out the output buffer. + int64 size_bytes = ShapeUtil::ByteSizeOf(dot_info_.result_shape); + b_->CreateMemSet(target_ptr, b_->getInt8(0), /*Size=*/size_bytes, + /*Align=*/llvm::MaybeAlign(1)); + + std::string name = + absl::StrCat("linalgMatMul_", dot_info_.result_shape.ToString(true), "_", + dot_info_.lhs_shape.ToString(true), "_", + dot_info_.rhs_shape.ToString(true)); + return EmitMlirFuncAndCall( + mlir_context_, b_, dot_info_.result_shape, operand_shapes, target_ptr, + operand_ptrs, name, [&](mlir::OpBuilder* builder, mlir::FuncOp function) { + mlir::edsc::ScopedContext scope(*builder, function.getLoc()); + mlir::Value a = function.getArgument(0), b = function.getArgument(1), + c = function.getArgument(2); + mlir::edsc::intrinsics::linalg_matmul(b, c, a); + mlir::edsc::intrinsics::std_ret(); + }); +} + void DotOpEmitter::EmitTiledLlvmIrGemm() { PrimitiveType primitive_type = dot_info_.result_shape.element_type(); MatMultDims mat_mult_dims = GetMatMultDims(); @@ -418,6 +459,9 @@ Status DotOpEmitter::Emit() { EmitTiledLlvmIrGemm(); return Status::OK(); + case DotImplementationStrategy::kLinalgMatmul: + return EmitLinalgMatmul(); + case DotImplementationStrategy::kEigen: return EmitCallToRuntime(); } @@ -886,9 +930,12 @@ DotImplementationStrategy GetDotImplementationStrategy( } if (IsAlignedGemm(dot_info, target_machine_features)) { - return CanEmitTiledLlvmIrGemm(config, dot_info, target_machine_features) - ? DotImplementationStrategy::kTiledLlvmIrGemm - : DotImplementationStrategy::kEigen; + if (CanEmitTiledLlvmIrGemm(config, dot_info, target_machine_features)) { + return options::UseLinalgForDot(config) + ? DotImplementationStrategy::kLinalgMatmul + : DotImplementationStrategy::kTiledLlvmIrGemm; + } + return DotImplementationStrategy::kEigen; } return DotImplementationStrategy::kNaiveLlvmIr; @@ -899,15 +946,15 @@ Status EmitNonBatchDotOperation( const llvm_ir::IrArray& lhs_array, const llvm_ir::IrArray& rhs_array, const llvm_ir::IrArray* addend_array, llvm::Value* executable_run_options_value, llvm::IRBuilder<>* b, - const HloModuleConfig& hlo_module_config, + mlir::MLIRContext* mlir_context, const HloModuleConfig& hlo_module_config, const TargetMachineFeatures& target_machine_features) { PrimitiveType type = target_array.GetShape().element_type(); TF_RET_CHECK(S32 == type || F16 == type || F32 == type || F64 == type || C64 == type || C128 == type); DotOpEmitter dot_emitter(std::move(dot_info), std::move(hlo_name), target_array, lhs_array, rhs_array, addend_array, - executable_run_options_value, b, hlo_module_config, - target_machine_features); + executable_run_options_value, b, mlir_context, + hlo_module_config, target_machine_features); return dot_emitter.Emit(); } @@ -981,7 +1028,7 @@ Status EmitBatchDotOperation( const HloInstruction& dot, const llvm_ir::IrArray& target_array, const llvm_ir::IrArray& lhs_array, const llvm_ir::IrArray& rhs_array, llvm::Value* executable_run_options_value, llvm::IRBuilder<>* b, - const HloModuleConfig& hlo_module_config, + mlir::MLIRContext* mlir_context, const HloModuleConfig& hlo_module_config, const TargetMachineFeatures& target_machine_features) { TF_RETURN_IF_ERROR(ValidateDotDimensionNumbers(dot.dot_dimension_numbers())); @@ -1039,7 +1086,7 @@ Status EmitBatchDotOperation( // Emit the inner non-batch dot operation. return EmitNonBatchDotOperation( dot_info, dot.name(), target_slice, lhs_slice, rhs_slice, nullptr, - executable_run_options_value, b, hlo_module_config, + executable_run_options_value, b, mlir_context, hlo_module_config, target_machine_features); }); } @@ -1089,7 +1136,7 @@ Status EmitDotOperation(const HloInstruction& dot, const llvm_ir::IrArray& rhs_array, const llvm_ir::IrArray* addend_array, llvm::Value* executable_run_options_value, - llvm::IRBuilder<>* b, + llvm::IRBuilder<>* b, mlir::MLIRContext* mlir_context, const HloModuleConfig& hlo_module_config, const TargetMachineFeatures& target_machine_features) { // This routine assumes that the dot operation is not in a parallelized @@ -1099,13 +1146,13 @@ Status EmitDotOperation(const HloInstruction& dot, if (IsBatchDot(dot)) { TF_RET_CHECK(addend_array == nullptr); return EmitBatchDotOperation(dot, target_array, lhs_array, rhs_array, - executable_run_options_value, b, + executable_run_options_value, b, mlir_context, hlo_module_config, target_machine_features); } return EmitNonBatchDotOperation(DotInfo(dot), dot.name(), target_array, lhs_array, rhs_array, addend_array, - executable_run_options_value, b, + executable_run_options_value, b, mlir_context, hlo_module_config, target_machine_features); } } // namespace cpu diff --git a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h index 105bd3005c8..d9cf8a2036b 100644 --- a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h +++ b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h @@ -18,6 +18,7 @@ limitations under the License. #include "absl/strings/string_view.h" #include "llvm/IR/IRBuilder.h" +#include "mlir/IR/MLIRContext.h" // from @llvm-project #include "tensorflow/compiler/xla/service/cpu/cpu_options.h" #include "tensorflow/compiler/xla/service/cpu/target_machine_features.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -63,7 +64,7 @@ Status EmitDotOperation(const HloInstruction& dot, const llvm_ir::IrArray& rhs_array, const llvm_ir::IrArray* addend_array, llvm::Value* executable_run_options_value, - llvm::IRBuilder<>* b, + llvm::IRBuilder<>* b, mlir::MLIRContext* mlir_context, const HloModuleConfig& hlo_module_config, const TargetMachineFeatures& target_machine_features); } // namespace cpu diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc index 70dde919afb..043ad68a196 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc @@ -89,8 +89,8 @@ using llvm_ir::SetToFirstInsertPoint; namespace cpu { IrEmitter::IrEmitter( - const HloModule& hlo_module, const BufferAssignment& assignment, - llvm::Module* llvm_module, + mlir::MLIRContext* mlir_context, const HloModule& hlo_module, + const BufferAssignment& assignment, llvm::Module* llvm_module, std::unordered_map instruction_to_profile_idx, std::unordered_map computation_to_profile_idx, const TargetMachineFeatures* target_machine_features, @@ -99,6 +99,7 @@ IrEmitter::IrEmitter( module_(llvm_module), arch_type_(llvm::Triple(llvm_module->getTargetTriple()).getArch()), b_(llvm_module->getContext()), + mlir_context_(mlir_context), instruction_to_profile_idx_(std::move(instruction_to_profile_idx)), computation_to_profile_idx_(std::move(computation_to_profile_idx)), alias_analysis_(hlo_module, assignment, &llvm_module->getContext()), @@ -898,7 +899,7 @@ Status IrEmitter::HandleDot(HloInstruction* dot) { // Dot operation is complicated so we delegate to a helper class. return EmitDotOperation(*dot, target_array, lhs_array, rhs_array, /*addend_array=*/nullptr, - GetExecutableRunOptionsArgument(), &b_, + GetExecutableRunOptionsArgument(), &b_, mlir_context_, hlo_module_config_, target_machine_features_); } @@ -2305,10 +2306,10 @@ Status IrEmitter::HandleFusion(HloInstruction* fusion) { llvm_ir::IrArray addend_array( GetIrArrayFor(fusion->operand(addend_param_number))); - TF_RETURN_IF_ERROR( - EmitDotOperation(*dot, target_array, lhs_array, rhs_array, - &addend_array, GetExecutableRunOptionsArgument(), &b_, - hlo_module_config_, target_machine_features_)); + TF_RETURN_IF_ERROR(EmitDotOperation( + *dot, target_array, lhs_array, rhs_array, &addend_array, + GetExecutableRunOptionsArgument(), &b_, mlir_context_, + hlo_module_config_, target_machine_features_)); return Status::OK(); } else { return Unimplemented("Fusion kind not implemented on CPU"); diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.h b/tensorflow/compiler/xla/service/cpu/ir_emitter.h index 9b0d11e9f3f..661785153d0 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.h +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.h @@ -17,6 +17,7 @@ limitations under the License. #define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_IR_EMITTER_H_ #include + #include #include #include @@ -32,6 +33,7 @@ limitations under the License. #include "llvm/IR/Module.h" #include "llvm/IR/Value.h" #include "llvm/Target/TargetMachine.h" +#include "mlir/IR/MLIRContext.h" // from @llvm-project #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/cpu/ir_function.h" #include "tensorflow/compiler/xla/service/cpu/target_machine_features.h" @@ -69,14 +71,16 @@ class IrEmitter : public DfsHloVisitorWithDefault, // hlo_module: the HLO module we are emitting IR for. // assignment: a BufferAssignment from which we know which buffers are used by // the HLO nodes. - // llvm_module: the LLVM module to emit IR into. + // mlir_context: the MLIR context used for IR emission. + // llvm_module: the LLVM module to emit IR into. It's built using the LLVM + // context inside of mlir_context. // instruction_to_profile_idx: the mapping from HLO instructions to their // index in the profiling array. // computation_to_profile_idx: the mapping from HLO computations to their // index in the profiling array. // emit_code_for_msan: whether emitted code should be compatible with msan. - IrEmitter(const HloModule& hlo_module, const BufferAssignment& assignment, - llvm::Module* llvm_module, + IrEmitter(mlir::MLIRContext* mlir_context, const HloModule& hlo_module, + const BufferAssignment& assignment, llvm::Module* llvm_module, std::unordered_map instruction_to_profile_idx, std::unordered_map @@ -442,6 +446,7 @@ class IrEmitter : public DfsHloVisitorWithDefault, // module's function list). std::unique_ptr compute_function_; llvm::IRBuilder<> b_; + mlir::MLIRContext* mlir_context_; // The buffer allocation slice for the root of the computation being compiled. // Only relevant for thread local computations. diff --git a/tensorflow/compiler/xla/service/cpu/mlir_emitter.cc b/tensorflow/compiler/xla/service/cpu/mlir_emitter.cc new file mode 100644 index 00000000000..e7d52c288d5 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/mlir_emitter.cc @@ -0,0 +1,132 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/cpu/mlir_emitter.h" + +#include "llvm/Linker/Linker.h" +#include "llvm/Transforms/IPO/Internalize.h" +#include "mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h" // from @llvm-project +#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" // from @llvm-project +#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" // from @llvm-project +#include "mlir/Dialect/Linalg/Passes.h" // from @llvm-project +#include "mlir/IR/Module.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassManager.h" // from @llvm-project +#include "mlir/Target/LLVMIR.h" // from @llvm-project +#include "tensorflow/compiler/mlir/xla/hlo_utils.h" + +namespace xla { +namespace cpu { +namespace { + +// Lower an MLIR module to an LLVM module. +std::unique_ptr MakeLLVMModule(mlir::OwningModuleRef module) { + mlir::PassManager manager(module->getContext()); + manager.addPass(mlir::createConvertLinalgToLoopsPass()); + manager.addPass(mlir::createConvertLinalgToLLVMPass()); + manager.addPass(mlir::createConvertVectorToLLVMPass()); + manager.addPass(mlir::createLowerToLLVMPass()); + CHECK(succeeded(manager.run(*module))); + return mlir::translateModuleToLLVMIR(*module); +} + +// Get arguments to pass a memref to an mlir function. +void BuildViewForBuffer(llvm::SmallVectorImpl *args, + llvm::IRBuilder<> *b, const Shape &opShape, + llvm::Value *op_val) { + llvm::Type *ty = op_val->getType(); + while (auto aty = llvm::dyn_cast( + llvm::cast(ty)->getElementType())) { + ty = aty->getElementType()->getPointerTo(); + } + op_val = b->CreateBitCast(op_val, ty); + + args->push_back(op_val); // Allocated pointer. + args->push_back(op_val); // Aligned pointer. + args->push_back(b->getInt64(0)); // Offset. + + // Sizes. + for (int64 dim : opShape.dimensions()) { + args->push_back(b->getInt64(dim)); + } + + int64_t accumulated_stride = 1; + llvm::SmallVector strides(opShape.rank(), 1); + for (int64 dim : LayoutUtil::MinorToMajor(opShape)) { + strides[dim] = accumulated_stride; + accumulated_stride *= opShape.dimensions(dim); + } + + // Strides. + for (int64 stride : strides) { + args->push_back(b->getInt64(stride)); + } +} +} // namespace + +Status EmitMlirFuncAndCall( + mlir::MLIRContext *context, llvm::IRBuilder<> *b, const Shape &result_shape, + llvm::ArrayRef operand_shapes, llvm::Value *result_ptr, + llvm::ArrayRef operand_ptrs, llvm::StringRef func_name, + llvm::function_ref emitter) { + llvm::Module *llvm_module = b->GetInsertBlock()->getParent()->getParent(); + mlir::Builder mlir_builder(context); + + // Get memref types for the inputs and output. + TF_ASSIGN_OR_RETURN(mlir::Type ret_memref, ConvertTensorShapeToMemRefType( + result_shape, mlir_builder)); + std::vector operand_types = {ret_memref}; + for (int i = 0; i != operand_shapes.size(); ++i) { + TF_ASSIGN_OR_RETURN( + mlir::Type op_memref, + ConvertTensorShapeToMemRefType(operand_shapes[i], mlir_builder)); + operand_types.push_back(op_memref); + } + + // Create the function an call the emission callback. + mlir::Location loc = mlir::UnknownLoc::get(context); + auto function = mlir::FuncOp::create( + loc, func_name, mlir::FunctionType::get(operand_types, {}, context)); + function.addEntryBlock(); + mlir::OwningModuleRef mlir_module = mlir::ModuleOp::create(loc); + mlir_module->push_back(function); + mlir::OpBuilder op_builder(&function.getBody()); + emitter(&op_builder, function); + + // Now link it all into the main LLVM module. + auto mlir_llvm_module = MakeLLVMModule(std::move(mlir_module)); + mlir_llvm_module->setDataLayout(llvm_module->getDataLayout()); + llvm::Linker::linkModules( + *llvm_module, std::move(mlir_llvm_module), llvm::Linker::None, + [](llvm::Module &M, const llvm::StringSet<> &GVS) { + llvm::internalizeModule(M, [&GVS](const llvm::GlobalValue &GV) { + return !GV.hasName() || (GVS.count(GV.getName()) == 0); + }); + }); + + // And leave behind a call to the function generated by MLIR. + llvm::Function *func = llvm_module->getFunction(func_name); + llvm::SmallVector op_vals; + BuildViewForBuffer(&op_vals, b, result_shape, result_ptr); + for (int i = 0; i != operand_shapes.size(); ++i) { + BuildViewForBuffer(&op_vals, b, operand_shapes[i], operand_ptrs[i]); + } + b->CreateCall(func, op_vals); + + return Status::OK(); +} + +} // namespace cpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/mlir_emitter.h b/tensorflow/compiler/xla/service/cpu/mlir_emitter.h new file mode 100644 index 00000000000..bc0741e851a --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/mlir_emitter.h @@ -0,0 +1,43 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_MLIR_EMITTER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_MLIR_EMITTER_H_ + +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/Value.h" +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Function.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "tensorflow/compiler/xla/shape.h" +#include "tensorflow/compiler/xla/status.h" + +namespace xla { +namespace cpu { + +// Create a new MLIR function with the name `func_name`, populate it with +// `emitter` and create a call, passing it the buffers defined by +// resultShape/resultPtr and operandShapes/operandPtrs. The function is added to +// the LLVM module at `b`s insertion point. +Status EmitMlirFuncAndCall( + mlir::MLIRContext *context, llvm::IRBuilder<> *b, const Shape &result_shape, + llvm::ArrayRef operand_shapes, llvm::Value *result_ptr, + llvm::ArrayRef operand_ptrs, llvm::StringRef func_name, + llvm::function_ref emitter); + +} // namespace cpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_MLIR_EMITTER_H_