diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD index 2f432cd9356..3460e65b0a2 100644 --- a/tensorflow/compiler/xla/service/cpu/BUILD +++ b/tensorflow/compiler/xla/service/cpu/BUILD @@ -118,6 +118,9 @@ cc_library( ":target_machine_features", "@com_google_absl//absl/base", "@com_google_absl//absl/types:span", + "@llvm-project//mlir:AllPassesAndDialectsNoRegistration", + "@llvm-project//mlir:ExecutionEngineUtils", + "@llvm-project//mlir:LLVMDialect", "//tensorflow/compiler/xla/service:copy_insertion", "//tensorflow/compiler/xla/service:hlo_casting_utils", "//tensorflow/compiler/xla/service:dump", @@ -366,6 +369,7 @@ cc_library( "@llvm-project//llvm:core", "@llvm-project//llvm:support", "@llvm-project//llvm:target", + "@llvm-project//mlir:IR", ], ) @@ -456,6 +460,7 @@ cc_library( ":cpu_options", ":cpu_runtime", ":ir_emission_utils", + ":mlir_emitter", ":target_machine_features", ":tiled_dot_emitter", ":vector_support_library", @@ -474,6 +479,10 @@ cc_library( "//tensorflow/core:lib", "@com_google_absl//absl/strings", "@llvm-project//llvm:core", + "@llvm-project//mlir:EDSC", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:LinalgOps", + "@llvm-project//mlir:StandardOps", ], ) @@ -1070,3 +1079,24 @@ tf_cc_test( "@llvm-project//llvm:target", ], ) + +cc_library( + name = "mlir_emitter", + srcs = ["mlir_emitter.cc"], + hdrs = ["mlir_emitter.h"], + deps = [ + "//tensorflow/compiler/mlir/xla:hlo_utils", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status", + "@llvm-project//llvm:core", + "@llvm-project//llvm:ipo", + "@llvm-project//llvm:linker", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:LLVMTransforms", + "@llvm-project//mlir:LinalgToLLVM", + "@llvm-project//mlir:LinalgTransforms", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:TargetLLVMIR", + "@llvm-project//mlir:VectorToLLVM", + ], +) diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc 
b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc index fe769bbdd2a..b2416ac2799 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc @@ -42,6 +42,8 @@ limitations under the License. #include "llvm/Support/TargetSelect.h" #include "llvm/Target/TargetMachine.h" #include "llvm/Target/TargetOptions.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" // from @llvm-project +#include "mlir/InitAllDialects.h" // from @llvm-project #include "tensorflow/compiler/xla/cpu_function_runtime.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/map_util.h" @@ -158,6 +160,8 @@ CpuCompiler::CpuCompiler() { // Initialize LLVM's MC layer for the native target. llvm::InitializeNativeTarget(); llvm::InitializeNativeTargetAsmPrinter(); + + mlir::registerAllDialects(); } namespace { @@ -606,9 +610,11 @@ StatusOr> CpuCompiler::RunBackend( user_post_optimization_hook_); // Compile must be thread-safe so create a new LLVM context for the module. - auto llvm_context = absl::make_unique(); - auto llvm_module = - absl::make_unique("__compute_module", *llvm_context); + mlir::MLIRContext mlir_context; + auto llvm_module = absl::make_unique( + "__compute_module", + mlir_context.getRegisteredDialect() + ->getLLVMContext()); auto jit = absl::make_unique( CompilerTargetOptions(module->config()), @@ -662,7 +668,7 @@ StatusOr> CpuCompiler::RunBackend( // before a caller computation. LLVMTargetMachineFeatures target_machine_features(jit->target_machine()); - IrEmitter ir_emitter(*module, *assignment, llvm_module.get(), + IrEmitter ir_emitter(&mlir_context, *module, *assignment, llvm_module.get(), std::move(instruction_to_profile_idx), std::move(computation_to_profile_idx), &target_machine_features, @@ -816,8 +822,11 @@ CpuCompiler::CompileAheadOfTime(std::unique_ptr module_group, opt_level)); // Compile must be thread-safe so create a new LLVM context for the module. 
- llvm::LLVMContext llvm_context; - llvm::Module llvm_module("__compute_module", llvm_context); + mlir::MLIRContext mlir_context; + llvm::Module llvm_module( + "__compute_module", + mlir_context.getRegisteredDialect() + ->getLLVMContext()); llvm_module.setDataLayout(target_machine->createDataLayout()); llvm_module.setTargetTriple(triple.getTriple()); if (pic_level != llvm::PICLevel::NotPIC) { @@ -866,7 +875,7 @@ CpuCompiler::CompileAheadOfTime(std::unique_ptr module_group, } LLVMTargetMachineFeatures target_machine_features(target_machine.get()); - IrEmitter ir_emitter(*module, *assignment, &llvm_module, + IrEmitter ir_emitter(&mlir_context, *module, *assignment, &llvm_module, std::move(instruction_to_profile_idx), std::move(computation_to_profile_idx), &target_machine_features, diff --git a/tensorflow/compiler/xla/service/cpu/cpu_options.cc b/tensorflow/compiler/xla/service/cpu/cpu_options.cc index ff654c83d61..c0222010fd9 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_options.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_options.cc @@ -25,6 +25,7 @@ const char* const kXlaOptimizeForSizeCpuOption = "xla_cpu_optimize_for_size"; const char* const kLlvmIrDotTilingFactor = "xla_llvm_dot_tiling_factor"; const char* const kXlaForceEnableExperimentalLlvmIrGemm = "xla_force_enable_experimental_llvm_ir_gemm"; +const char* const kXlaUseLinalgForDot = "xla_use_linalg_for_dot"; const char* const kLlvmIrGemmTileSize = "xla_llvm_ir_gemm_tile_size"; } // namespace @@ -63,6 +64,12 @@ bool ForceEnableExperimentalLlvmIrGemm(const HloModuleConfig& config) { return extra_options_map.count(kXlaForceEnableExperimentalLlvmIrGemm) > 0; } +bool UseLinalgForDot(const HloModuleConfig& config) { + const auto& extra_options_map = + config.debug_options().xla_backend_extra_options(); + return extra_options_map.count(kXlaUseLinalgForDot) > 0; // Only the presence of the key is tested; its mapped value is not read. +} + static absl::string_view RemoveSuffix(absl::string_view str, absl::string_view suffix) { CHECK_GE(str.size(), suffix.size()); diff --git 
a/tensorflow/compiler/xla/service/cpu/cpu_options.h b/tensorflow/compiler/xla/service/cpu/cpu_options.h index 99e6702d14a..5d25aef6912 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_options.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_options.h @@ -27,6 +27,7 @@ namespace options { bool OptimizeForSizeRequested(const HloModuleConfig& config); bool VectorizedReduceDisabled(const HloModuleConfig& config); bool ForceEnableExperimentalLlvmIrGemm(const HloModuleConfig& config); +bool UseLinalgForDot(const HloModuleConfig& config); absl::optional LlvmIrGemvTilingFactor(const HloModuleConfig& config); absl::optional> LlvmIrGemmTileSize( const HloModuleConfig& config); diff --git a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc index 7dba826b65c..e1ad14600d7 100644 --- a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc @@ -23,8 +23,17 @@ limitations under the License. 
#include "llvm/IR/Instructions.h" #include "llvm/IR/Module.h" #include "llvm/IR/Value.h" +#include "mlir/Dialect/Linalg/EDSC/Intrinsics.h" // from @llvm-project +#include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h" // from @llvm-project +#include "mlir/EDSC/Builders.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Function.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "tensorflow/compiler/xla/service/cpu/cpu_options.h" #include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h" #include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h" +#include "tensorflow/compiler/xla/service/cpu/mlir_emitter.h" #include "tensorflow/compiler/xla/service/cpu/target_machine_features.h" #include "tensorflow/compiler/xla/service/cpu/tiled_dot_emitter.h" #include "tensorflow/compiler/xla/service/cpu/vector_support_library.h" @@ -89,6 +98,9 @@ enum class DotImplementationStrategy { // and the output have to be row major. kTiledLlvmIrGemm, + // The dot operation is lowered into linalg.matmul op and lowered to LLVM IR. + kLinalgMatmul, + // The dot operation is lowered into a call into an Eigen routine. No fusions // are supported today. The two inputs and the output have to be row major. // However, we do allow transposing either the LHS or the RHS as part of the @@ -112,7 +124,7 @@ class DotOpEmitter { const llvm_ir::IrArray& rhs_array, const llvm_ir::IrArray* addend_array, llvm::Value* executable_run_options_value, - llvm::IRBuilder<>* b, + llvm::IRBuilder<>* b, mlir::MLIRContext* mlir_context, const HloModuleConfig& hlo_module_config, const TargetMachineFeatures& target_machine_features); @@ -163,6 +175,9 @@ class DotOpEmitter { // Lowers the dot operation as a tiled Matrix*Matrix loop. void EmitTiledLlvmIrGemm(); + // Lowers the dot operation through MLIR's linalg.matmul. 
+ Status EmitLinalgMatmul(); + // Lowers the dot operation as a naive nested loop that computes the result // one element at a time. void EmitNaiveLlvmIrGemm(); @@ -194,20 +209,19 @@ class DotOpEmitter { const llvm_ir::IrArray* addend_array_; llvm::Value* executable_run_options_value_; llvm::IRBuilder<>* b_; + mlir::MLIRContext* mlir_context_; const HloModuleConfig& hlo_module_config_; const TargetMachineFeatures& target_machine_features_; }; } // namespace -DotOpEmitter::DotOpEmitter(DotInfo dot_info, string dot_hlo_name, - const llvm_ir::IrArray& target_array, - const llvm_ir::IrArray& lhs_array, - const llvm_ir::IrArray& rhs_array, - const llvm_ir::IrArray* addend_array, - llvm::Value* executable_run_options_value, - llvm::IRBuilder<>* b, - const HloModuleConfig& hlo_module_config, - const TargetMachineFeatures& target_machine_features) +DotOpEmitter::DotOpEmitter( + DotInfo dot_info, string dot_hlo_name, const llvm_ir::IrArray& target_array, + const llvm_ir::IrArray& lhs_array, const llvm_ir::IrArray& rhs_array, + const llvm_ir::IrArray* addend_array, + llvm::Value* executable_run_options_value, llvm::IRBuilder<>* b, + mlir::MLIRContext* mlir_context, const HloModuleConfig& hlo_module_config, + const TargetMachineFeatures& target_machine_features) : dot_info_(std::move(dot_info)), dot_hlo_name_(std::move(dot_hlo_name)), target_array_(target_array), @@ -216,9 +230,36 @@ DotOpEmitter::DotOpEmitter(DotInfo dot_info, string dot_hlo_name, addend_array_(addend_array), executable_run_options_value_(executable_run_options_value), b_(b), + mlir_context_(mlir_context), hlo_module_config_(hlo_module_config), target_machine_features_(target_machine_features) {} +Status DotOpEmitter::EmitLinalgMatmul() { // NOTE(review): addend_array_ is not read on this path; confirm this strategy is never selected for a dot that has an addend. + Shape operand_shapes[] = {dot_info_.lhs_shape, dot_info_.rhs_shape}; + llvm::Value* operand_ptrs[] = {lhs_array_.GetBasePointer(), + rhs_array_.GetBasePointer()}; + llvm::Value* target_ptr = target_array_.GetBasePointer(); + + // Zero out the output buffer before the MLIR function is called (presumably because linalg.matmul accumulates into its output; TODO confirm).
+ int64 size_bytes = ShapeUtil::ByteSizeOf(dot_info_.result_shape); + b_->CreateMemSet(target_ptr, b_->getInt8(0), /*Size=*/size_bytes, + /*Align=*/llvm::MaybeAlign(1)); + + std::string name = + absl::StrCat("linalgMatMul_", dot_info_.result_shape.ToString(true), "_", + dot_info_.lhs_shape.ToString(true), "_", + dot_info_.rhs_shape.ToString(true)); // The function name encodes the result, lhs and rhs shapes, in that order. + return EmitMlirFuncAndCall( + mlir_context_, b_, dot_info_.result_shape, operand_shapes, target_ptr, + operand_ptrs, name, [&](mlir::OpBuilder* builder, mlir::FuncOp function) { + mlir::edsc::ScopedContext scope(*builder, function.getLoc()); + mlir::Value a = function.getArgument(0), b = function.getArgument(1), + c = function.getArgument(2); // a: result, b: lhs, c: rhs (EmitMlirFuncAndCall places the result type first). + mlir::edsc::intrinsics::linalg_matmul(b, c, a); // Called as (lhs, rhs, result). + mlir::edsc::intrinsics::std_ret(); + }); +} + void DotOpEmitter::EmitTiledLlvmIrGemm() { PrimitiveType primitive_type = dot_info_.result_shape.element_type(); MatMultDims mat_mult_dims = GetMatMultDims(); @@ -418,6 +459,9 @@ Status DotOpEmitter::Emit() { EmitTiledLlvmIrGemm(); return Status::OK(); + case DotImplementationStrategy::kLinalgMatmul: + return EmitLinalgMatmul(); + case DotImplementationStrategy::kEigen: return EmitCallToRuntime(); } @@ -886,9 +930,12 @@ DotImplementationStrategy GetDotImplementationStrategy( } if (IsAlignedGemm(dot_info, target_machine_features)) { - return CanEmitTiledLlvmIrGemm(config, dot_info, target_machine_features) - ? DotImplementationStrategy::kTiledLlvmIrGemm - : DotImplementationStrategy::kEigen; + if (CanEmitTiledLlvmIrGemm(config, dot_info, target_machine_features)) { + return options::UseLinalgForDot(config) + ? 
DotImplementationStrategy::kLinalgMatmul + : DotImplementationStrategy::kTiledLlvmIrGemm; + } + return DotImplementationStrategy::kEigen; } return DotImplementationStrategy::kNaiveLlvmIr; @@ -899,15 +946,15 @@ Status EmitNonBatchDotOperation( const llvm_ir::IrArray& lhs_array, const llvm_ir::IrArray& rhs_array, const llvm_ir::IrArray* addend_array, llvm::Value* executable_run_options_value, llvm::IRBuilder<>* b, - const HloModuleConfig& hlo_module_config, + mlir::MLIRContext* mlir_context, const HloModuleConfig& hlo_module_config, const TargetMachineFeatures& target_machine_features) { PrimitiveType type = target_array.GetShape().element_type(); TF_RET_CHECK(S32 == type || F16 == type || F32 == type || F64 == type || C64 == type || C128 == type); DotOpEmitter dot_emitter(std::move(dot_info), std::move(hlo_name), target_array, lhs_array, rhs_array, addend_array, - executable_run_options_value, b, hlo_module_config, - target_machine_features); + executable_run_options_value, b, mlir_context, + hlo_module_config, target_machine_features); return dot_emitter.Emit(); } @@ -981,7 +1028,7 @@ Status EmitBatchDotOperation( const HloInstruction& dot, const llvm_ir::IrArray& target_array, const llvm_ir::IrArray& lhs_array, const llvm_ir::IrArray& rhs_array, llvm::Value* executable_run_options_value, llvm::IRBuilder<>* b, - const HloModuleConfig& hlo_module_config, + mlir::MLIRContext* mlir_context, const HloModuleConfig& hlo_module_config, const TargetMachineFeatures& target_machine_features) { TF_RETURN_IF_ERROR(ValidateDotDimensionNumbers(dot.dot_dimension_numbers())); @@ -1039,7 +1086,7 @@ Status EmitBatchDotOperation( // Emit the inner non-batch dot operation. 
return EmitNonBatchDotOperation( dot_info, dot.name(), target_slice, lhs_slice, rhs_slice, nullptr, - executable_run_options_value, b, hlo_module_config, + executable_run_options_value, b, mlir_context, hlo_module_config, target_machine_features); }); } @@ -1089,7 +1136,7 @@ Status EmitDotOperation(const HloInstruction& dot, const llvm_ir::IrArray& rhs_array, const llvm_ir::IrArray* addend_array, llvm::Value* executable_run_options_value, - llvm::IRBuilder<>* b, + llvm::IRBuilder<>* b, mlir::MLIRContext* mlir_context, const HloModuleConfig& hlo_module_config, const TargetMachineFeatures& target_machine_features) { // This routine assumes that the dot operation is not in a parallelized @@ -1099,13 +1146,13 @@ Status EmitDotOperation(const HloInstruction& dot, if (IsBatchDot(dot)) { TF_RET_CHECK(addend_array == nullptr); return EmitBatchDotOperation(dot, target_array, lhs_array, rhs_array, - executable_run_options_value, b, + executable_run_options_value, b, mlir_context, hlo_module_config, target_machine_features); } return EmitNonBatchDotOperation(DotInfo(dot), dot.name(), target_array, lhs_array, rhs_array, addend_array, - executable_run_options_value, b, + executable_run_options_value, b, mlir_context, hlo_module_config, target_machine_features); } } // namespace cpu diff --git a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h index 105bd3005c8..d9cf8a2036b 100644 --- a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h +++ b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h @@ -18,6 +18,7 @@ limitations under the License. 
#include "absl/strings/string_view.h" #include "llvm/IR/IRBuilder.h" +#include "mlir/IR/MLIRContext.h" // from @llvm-project #include "tensorflow/compiler/xla/service/cpu/cpu_options.h" #include "tensorflow/compiler/xla/service/cpu/target_machine_features.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -63,7 +64,7 @@ Status EmitDotOperation(const HloInstruction& dot, const llvm_ir::IrArray& rhs_array, const llvm_ir::IrArray* addend_array, llvm::Value* executable_run_options_value, - llvm::IRBuilder<>* b, + llvm::IRBuilder<>* b, mlir::MLIRContext* mlir_context, const HloModuleConfig& hlo_module_config, const TargetMachineFeatures& target_machine_features); } // namespace cpu diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc index 70dde919afb..043ad68a196 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc @@ -89,8 +89,8 @@ using llvm_ir::SetToFirstInsertPoint; namespace cpu { IrEmitter::IrEmitter( - const HloModule& hlo_module, const BufferAssignment& assignment, - llvm::Module* llvm_module, + mlir::MLIRContext* mlir_context, const HloModule& hlo_module, + const BufferAssignment& assignment, llvm::Module* llvm_module, std::unordered_map instruction_to_profile_idx, std::unordered_map computation_to_profile_idx, const TargetMachineFeatures* target_machine_features, @@ -99,6 +99,7 @@ IrEmitter::IrEmitter( module_(llvm_module), arch_type_(llvm::Triple(llvm_module->getTargetTriple()).getArch()), b_(llvm_module->getContext()), + mlir_context_(mlir_context), instruction_to_profile_idx_(std::move(instruction_to_profile_idx)), computation_to_profile_idx_(std::move(computation_to_profile_idx)), alias_analysis_(hlo_module, assignment, &llvm_module->getContext()), @@ -898,7 +899,7 @@ Status IrEmitter::HandleDot(HloInstruction* dot) { // Dot operation is complicated so we delegate to a helper class. 
return EmitDotOperation(*dot, target_array, lhs_array, rhs_array, /*addend_array=*/nullptr, - GetExecutableRunOptionsArgument(), &b_, + GetExecutableRunOptionsArgument(), &b_, mlir_context_, hlo_module_config_, target_machine_features_); } @@ -2305,10 +2306,10 @@ Status IrEmitter::HandleFusion(HloInstruction* fusion) { llvm_ir::IrArray addend_array( GetIrArrayFor(fusion->operand(addend_param_number))); - TF_RETURN_IF_ERROR( - EmitDotOperation(*dot, target_array, lhs_array, rhs_array, - &addend_array, GetExecutableRunOptionsArgument(), &b_, - hlo_module_config_, target_machine_features_)); + TF_RETURN_IF_ERROR(EmitDotOperation( + *dot, target_array, lhs_array, rhs_array, &addend_array, + GetExecutableRunOptionsArgument(), &b_, mlir_context_, + hlo_module_config_, target_machine_features_)); return Status::OK(); } else { return Unimplemented("Fusion kind not implemented on CPU"); diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.h b/tensorflow/compiler/xla/service/cpu/ir_emitter.h index 9b0d11e9f3f..661785153d0 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.h +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.h @@ -17,6 +17,7 @@ limitations under the License. #define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_IR_EMITTER_H_ #include + #include #include #include @@ -32,6 +33,7 @@ limitations under the License. #include "llvm/IR/Module.h" #include "llvm/IR/Value.h" #include "llvm/Target/TargetMachine.h" +#include "mlir/IR/MLIRContext.h" // from @llvm-project #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/cpu/ir_function.h" #include "tensorflow/compiler/xla/service/cpu/target_machine_features.h" @@ -69,14 +71,16 @@ class IrEmitter : public DfsHloVisitorWithDefault, // hlo_module: the HLO module we are emitting IR for. // assignment: a BufferAssignment from which we know which buffers are used by // the HLO nodes. - // llvm_module: the LLVM module to emit IR into. 
+ // mlir_context: the MLIR context used for IR emission. + // llvm_module: the LLVM module to emit IR into. It's built using the LLVM + // context inside of mlir_context. // instruction_to_profile_idx: the mapping from HLO instructions to their // index in the profiling array. // computation_to_profile_idx: the mapping from HLO computations to their // index in the profiling array. // emit_code_for_msan: whether emitted code should be compatible with msan. - IrEmitter(const HloModule& hlo_module, const BufferAssignment& assignment, - llvm::Module* llvm_module, + IrEmitter(mlir::MLIRContext* mlir_context, const HloModule& hlo_module, + const BufferAssignment& assignment, llvm::Module* llvm_module, std::unordered_map instruction_to_profile_idx, std::unordered_map @@ -442,6 +446,7 @@ class IrEmitter : public DfsHloVisitorWithDefault, // module's function list). std::unique_ptr compute_function_; llvm::IRBuilder<> b_; + mlir::MLIRContext* mlir_context_; // The buffer allocation slice for the root of the computation being compiled. // Only relevant for thread local computations. diff --git a/tensorflow/compiler/xla/service/cpu/mlir_emitter.cc b/tensorflow/compiler/xla/service/cpu/mlir_emitter.cc new file mode 100644 index 00000000000..e7d52c288d5 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/mlir_emitter.cc @@ -0,0 +1,132 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/cpu/mlir_emitter.h" + +#include "llvm/Linker/Linker.h" +#include "llvm/Transforms/IPO/Internalize.h" +#include "mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h" // from @llvm-project +#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" // from @llvm-project +#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" // from @llvm-project +#include "mlir/Dialect/Linalg/Passes.h" // from @llvm-project +#include "mlir/IR/Module.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassManager.h" // from @llvm-project +#include "mlir/Target/LLVMIR.h" // from @llvm-project +#include "tensorflow/compiler/mlir/xla/hlo_utils.h" + +namespace xla { +namespace cpu { +namespace { + +// Lowers an MLIR module to an LLVM module: runs the four conversion passes below in order, CHECK-fails if the pipeline does not succeed, then translates the module to LLVM IR. +std::unique_ptr MakeLLVMModule(mlir::OwningModuleRef module) { + mlir::PassManager manager(module->getContext()); + manager.addPass(mlir::createConvertLinalgToLoopsPass()); + manager.addPass(mlir::createConvertLinalgToLLVMPass()); + manager.addPass(mlir::createConvertVectorToLLVMPass()); + manager.addPass(mlir::createLowerToLLVMPass()); + CHECK(succeeded(manager.run(*module))); + return mlir::translateModuleToLLVMIR(*module); +} + +// Appends to `args` the values that describe one buffer as a memref argument: allocated pointer, aligned pointer, offset, then one size and one stride per dimension of `opShape`. +void BuildViewForBuffer(llvm::SmallVectorImpl *args, + llvm::IRBuilder<> *b, const Shape &opShape, + llvm::Value *op_val) { + llvm::Type *ty = op_val->getType(); // Below, `ty` is repeatedly replaced by a pointer to the pointee's element type for as long as the dyn_cast on the pointee succeeds. + while (auto aty = llvm::dyn_cast( + llvm::cast(ty)->getElementType())) { + ty = aty->getElementType()->getPointerTo(); + } + op_val = b->CreateBitCast(op_val, ty); + + args->push_back(op_val); // Allocated pointer. + args->push_back(op_val); // Aligned pointer. + args->push_back(b->getInt64(0)); // Offset. + + // Sizes, one per dimension in opShape's own dimension order.
+ for (int64 dim : opShape.dimensions()) { + args->push_back(b->getInt64(dim)); + } + + int64_t accumulated_stride = 1; // Strides are counted in elements and accumulated starting from the minor-most dimension of the layout. + llvm::SmallVector strides(opShape.rank(), 1); + for (int64 dim : LayoutUtil::MinorToMajor(opShape)) { + strides[dim] = accumulated_stride; + accumulated_stride *= opShape.dimensions(dim); + } + + // Strides, pushed in dimension-index order (not in layout order). + for (int64 stride : strides) { + args->push_back(b->getInt64(stride)); + } +} +} // namespace + +Status EmitMlirFuncAndCall( + mlir::MLIRContext *context, llvm::IRBuilder<> *b, const Shape &result_shape, + llvm::ArrayRef operand_shapes, llvm::Value *result_ptr, + llvm::ArrayRef operand_ptrs, llvm::StringRef func_name, + llvm::function_ref emitter) { + llvm::Module *llvm_module = b->GetInsertBlock()->getParent()->getParent(); + mlir::Builder mlir_builder(context); + + // Get memref types for the inputs and output. The result type is placed first, so it becomes argument 0 of the generated function. + TF_ASSIGN_OR_RETURN(mlir::Type ret_memref, ConvertTensorShapeToMemRefType( + result_shape, mlir_builder)); + std::vector operand_types = {ret_memref}; + for (int i = 0; i != operand_shapes.size(); ++i) { + TF_ASSIGN_OR_RETURN( + mlir::Type op_memref, + ConvertTensorShapeToMemRefType(operand_shapes[i], mlir_builder)); + operand_types.push_back(op_memref); + } + + // Create the function and call the emission callback. The function has no results and lives in a fresh MLIR module of its own. + mlir::Location loc = mlir::UnknownLoc::get(context); + auto function = mlir::FuncOp::create( + loc, func_name, mlir::FunctionType::get(operand_types, {}, context)); + function.addEntryBlock(); + mlir::OwningModuleRef mlir_module = mlir::ModuleOp::create(loc); + mlir_module->push_back(function); + mlir::OpBuilder op_builder(&function.getBody()); + emitter(&op_builder, function); + + // Now link it all into the main LLVM module. NOTE(review): the value returned by linkModules is discarded, so a link failure would go unnoticed here.
+ auto mlir_llvm_module = MakeLLVMModule(std::move(mlir_module)); + mlir_llvm_module->setDataLayout(llvm_module->getDataLayout()); + llvm::Linker::linkModules( + *llvm_module, std::move(mlir_llvm_module), llvm::Linker::None, + [](llvm::Module &M, const llvm::StringSet<> &GVS) { + llvm::internalizeModule(M, [&GVS](const llvm::GlobalValue &GV) { + return !GV.hasName() || (GVS.count(GV.getName()) == 0); // True for unnamed globals and for names absent from GVS (presumably so only the freshly linked symbols are internalized; confirm). + }); + }); + + // And leave behind a call to the function generated by MLIR. NOTE(review): `func` is passed to CreateCall without a null check. + llvm::Function *func = llvm_module->getFunction(func_name); + llvm::SmallVector op_vals; + BuildViewForBuffer(&op_vals, b, result_shape, result_ptr); + for (int i = 0; i != operand_shapes.size(); ++i) { + BuildViewForBuffer(&op_vals, b, operand_shapes[i], operand_ptrs[i]); + } + b->CreateCall(func, op_vals); + + return Status::OK(); +} + +} // namespace cpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/mlir_emitter.h b/tensorflow/compiler/xla/service/cpu/mlir_emitter.h new file mode 100644 index 00000000000..bc0741e851a --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/mlir_emitter.h @@ -0,0 +1,43 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_MLIR_EMITTER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_MLIR_EMITTER_H_ + +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/Value.h" +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Function.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "tensorflow/compiler/xla/shape.h" +#include "tensorflow/compiler/xla/status.h" + +namespace xla { +namespace cpu { + +// Create a new MLIR function with the name `func_name`, populate it with +// `emitter` and create a call, passing it the buffers defined by +// result_shape/result_ptr and operand_shapes/operand_ptrs. The function is +// linked into the LLVM module that contains `b`'s insertion point. +Status EmitMlirFuncAndCall( + mlir::MLIRContext *context, llvm::IRBuilder<> *b, const Shape &result_shape, + llvm::ArrayRef operand_shapes, llvm::Value *result_ptr, + llvm::ArrayRef operand_ptrs, llvm::StringRef func_name, + llvm::function_ref emitter); + +} // namespace cpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_MLIR_EMITTER_H_