[XLA:CPU] Plumb through a minimal emitter for matmuls using the mlir linalg dialect
This is just the most basic lowering and will generate linalg.matmul for small matmuls and then convert to loops. The result is fairly slow, but we can iterate on that. To make XLA use it set XLA_FLAGS=--xla_backend_extra_options=xla_use_linalg_for_dot PiperOrigin-RevId: 312471829 Change-Id: I213d1f6114671bc595ac1647d3689736ee8f56f4
This commit is contained in:
parent
cdd4e5e918
commit
01b38cd7c6
|
@ -118,6 +118,9 @@ cc_library(
|
||||||
":target_machine_features",
|
":target_machine_features",
|
||||||
"@com_google_absl//absl/base",
|
"@com_google_absl//absl/base",
|
||||||
"@com_google_absl//absl/types:span",
|
"@com_google_absl//absl/types:span",
|
||||||
|
"@llvm-project//mlir:AllPassesAndDialectsNoRegistration",
|
||||||
|
"@llvm-project//mlir:ExecutionEngineUtils",
|
||||||
|
"@llvm-project//mlir:LLVMDialect",
|
||||||
"//tensorflow/compiler/xla/service:copy_insertion",
|
"//tensorflow/compiler/xla/service:copy_insertion",
|
||||||
"//tensorflow/compiler/xla/service:hlo_casting_utils",
|
"//tensorflow/compiler/xla/service:hlo_casting_utils",
|
||||||
"//tensorflow/compiler/xla/service:dump",
|
"//tensorflow/compiler/xla/service:dump",
|
||||||
|
@ -366,6 +369,7 @@ cc_library(
|
||||||
"@llvm-project//llvm:core",
|
"@llvm-project//llvm:core",
|
||||||
"@llvm-project//llvm:support",
|
"@llvm-project//llvm:support",
|
||||||
"@llvm-project//llvm:target",
|
"@llvm-project//llvm:target",
|
||||||
|
"@llvm-project//mlir:IR",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -456,6 +460,7 @@ cc_library(
|
||||||
":cpu_options",
|
":cpu_options",
|
||||||
":cpu_runtime",
|
":cpu_runtime",
|
||||||
":ir_emission_utils",
|
":ir_emission_utils",
|
||||||
|
":mlir_emitter",
|
||||||
":target_machine_features",
|
":target_machine_features",
|
||||||
":tiled_dot_emitter",
|
":tiled_dot_emitter",
|
||||||
":vector_support_library",
|
":vector_support_library",
|
||||||
|
@ -474,6 +479,10 @@ cc_library(
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
"@com_google_absl//absl/strings",
|
"@com_google_absl//absl/strings",
|
||||||
"@llvm-project//llvm:core",
|
"@llvm-project//llvm:core",
|
||||||
|
"@llvm-project//mlir:EDSC",
|
||||||
|
"@llvm-project//mlir:IR",
|
||||||
|
"@llvm-project//mlir:LinalgOps",
|
||||||
|
"@llvm-project//mlir:StandardOps",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -1070,3 +1079,24 @@ tf_cc_test(
|
||||||
"@llvm-project//llvm:target",
|
"@llvm-project//llvm:target",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "mlir_emitter",
|
||||||
|
srcs = ["mlir_emitter.cc"],
|
||||||
|
hdrs = ["mlir_emitter.h"],
|
||||||
|
deps = [
|
||||||
|
"//tensorflow/compiler/mlir/xla:hlo_utils",
|
||||||
|
"//tensorflow/compiler/xla:shape_util",
|
||||||
|
"//tensorflow/compiler/xla:status",
|
||||||
|
"@llvm-project//llvm:core",
|
||||||
|
"@llvm-project//llvm:ipo",
|
||||||
|
"@llvm-project//llvm:linker",
|
||||||
|
"@llvm-project//mlir:IR",
|
||||||
|
"@llvm-project//mlir:LLVMTransforms",
|
||||||
|
"@llvm-project//mlir:LinalgToLLVM",
|
||||||
|
"@llvm-project//mlir:LinalgTransforms",
|
||||||
|
"@llvm-project//mlir:Pass",
|
||||||
|
"@llvm-project//mlir:TargetLLVMIR",
|
||||||
|
"@llvm-project//mlir:VectorToLLVM",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
|
@ -42,6 +42,8 @@ limitations under the License.
|
||||||
#include "llvm/Support/TargetSelect.h"
|
#include "llvm/Support/TargetSelect.h"
|
||||||
#include "llvm/Target/TargetMachine.h"
|
#include "llvm/Target/TargetMachine.h"
|
||||||
#include "llvm/Target/TargetOptions.h"
|
#include "llvm/Target/TargetOptions.h"
|
||||||
|
#include "mlir/Dialect/LLVMIR/LLVMDialect.h" // from @llvm-project
|
||||||
|
#include "mlir/InitAllDialects.h" // from @llvm-project
|
||||||
#include "tensorflow/compiler/xla/cpu_function_runtime.h"
|
#include "tensorflow/compiler/xla/cpu_function_runtime.h"
|
||||||
#include "tensorflow/compiler/xla/literal.h"
|
#include "tensorflow/compiler/xla/literal.h"
|
||||||
#include "tensorflow/compiler/xla/map_util.h"
|
#include "tensorflow/compiler/xla/map_util.h"
|
||||||
|
@ -158,6 +160,8 @@ CpuCompiler::CpuCompiler() {
|
||||||
// Initialize LLVM's MC layer for the native target.
|
// Initialize LLVM's MC layer for the native target.
|
||||||
llvm::InitializeNativeTarget();
|
llvm::InitializeNativeTarget();
|
||||||
llvm::InitializeNativeTargetAsmPrinter();
|
llvm::InitializeNativeTargetAsmPrinter();
|
||||||
|
|
||||||
|
mlir::registerAllDialects();
|
||||||
}
|
}
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
@ -606,9 +610,11 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::RunBackend(
|
||||||
user_post_optimization_hook_);
|
user_post_optimization_hook_);
|
||||||
|
|
||||||
// Compile must be thread-safe so create a new LLVM context for the module.
|
// Compile must be thread-safe so create a new LLVM context for the module.
|
||||||
auto llvm_context = absl::make_unique<llvm::LLVMContext>();
|
mlir::MLIRContext mlir_context;
|
||||||
auto llvm_module =
|
auto llvm_module = absl::make_unique<llvm::Module>(
|
||||||
absl::make_unique<llvm::Module>("__compute_module", *llvm_context);
|
"__compute_module",
|
||||||
|
mlir_context.getRegisteredDialect<mlir::LLVM::LLVMDialect>()
|
||||||
|
->getLLVMContext());
|
||||||
|
|
||||||
auto jit = absl::make_unique<SimpleOrcJIT>(
|
auto jit = absl::make_unique<SimpleOrcJIT>(
|
||||||
CompilerTargetOptions(module->config()),
|
CompilerTargetOptions(module->config()),
|
||||||
|
@ -662,7 +668,7 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::RunBackend(
|
||||||
// before a caller computation.
|
// before a caller computation.
|
||||||
|
|
||||||
LLVMTargetMachineFeatures target_machine_features(jit->target_machine());
|
LLVMTargetMachineFeatures target_machine_features(jit->target_machine());
|
||||||
IrEmitter ir_emitter(*module, *assignment, llvm_module.get(),
|
IrEmitter ir_emitter(&mlir_context, *module, *assignment, llvm_module.get(),
|
||||||
std::move(instruction_to_profile_idx),
|
std::move(instruction_to_profile_idx),
|
||||||
std::move(computation_to_profile_idx),
|
std::move(computation_to_profile_idx),
|
||||||
&target_machine_features,
|
&target_machine_features,
|
||||||
|
@ -816,8 +822,11 @@ CpuCompiler::CompileAheadOfTime(std::unique_ptr<HloModuleGroup> module_group,
|
||||||
opt_level));
|
opt_level));
|
||||||
|
|
||||||
// Compile must be thread-safe so create a new LLVM context for the module.
|
// Compile must be thread-safe so create a new LLVM context for the module.
|
||||||
llvm::LLVMContext llvm_context;
|
mlir::MLIRContext mlir_context;
|
||||||
llvm::Module llvm_module("__compute_module", llvm_context);
|
llvm::Module llvm_module(
|
||||||
|
"__compute_module",
|
||||||
|
mlir_context.getRegisteredDialect<mlir::LLVM::LLVMDialect>()
|
||||||
|
->getLLVMContext());
|
||||||
llvm_module.setDataLayout(target_machine->createDataLayout());
|
llvm_module.setDataLayout(target_machine->createDataLayout());
|
||||||
llvm_module.setTargetTriple(triple.getTriple());
|
llvm_module.setTargetTriple(triple.getTriple());
|
||||||
if (pic_level != llvm::PICLevel::NotPIC) {
|
if (pic_level != llvm::PICLevel::NotPIC) {
|
||||||
|
@ -866,7 +875,7 @@ CpuCompiler::CompileAheadOfTime(std::unique_ptr<HloModuleGroup> module_group,
|
||||||
}
|
}
|
||||||
|
|
||||||
LLVMTargetMachineFeatures target_machine_features(target_machine.get());
|
LLVMTargetMachineFeatures target_machine_features(target_machine.get());
|
||||||
IrEmitter ir_emitter(*module, *assignment, &llvm_module,
|
IrEmitter ir_emitter(&mlir_context, *module, *assignment, &llvm_module,
|
||||||
std::move(instruction_to_profile_idx),
|
std::move(instruction_to_profile_idx),
|
||||||
std::move(computation_to_profile_idx),
|
std::move(computation_to_profile_idx),
|
||||||
&target_machine_features,
|
&target_machine_features,
|
||||||
|
|
|
@ -25,6 +25,7 @@ const char* const kXlaOptimizeForSizeCpuOption = "xla_cpu_optimize_for_size";
|
||||||
const char* const kLlvmIrDotTilingFactor = "xla_llvm_dot_tiling_factor";
|
const char* const kLlvmIrDotTilingFactor = "xla_llvm_dot_tiling_factor";
|
||||||
const char* const kXlaForceEnableExperimentalLlvmIrGemm =
|
const char* const kXlaForceEnableExperimentalLlvmIrGemm =
|
||||||
"xla_force_enable_experimental_llvm_ir_gemm";
|
"xla_force_enable_experimental_llvm_ir_gemm";
|
||||||
|
const char* const kXlaUseLinalgForDot = "xla_use_linalg_for_dot";
|
||||||
const char* const kLlvmIrGemmTileSize = "xla_llvm_ir_gemm_tile_size";
|
const char* const kLlvmIrGemmTileSize = "xla_llvm_ir_gemm_tile_size";
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
@ -63,6 +64,12 @@ bool ForceEnableExperimentalLlvmIrGemm(const HloModuleConfig& config) {
|
||||||
return extra_options_map.count(kXlaForceEnableExperimentalLlvmIrGemm) > 0;
|
return extra_options_map.count(kXlaForceEnableExperimentalLlvmIrGemm) > 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool UseLinalgForDot(const HloModuleConfig& config) {
|
||||||
|
const auto& extra_options_map =
|
||||||
|
config.debug_options().xla_backend_extra_options();
|
||||||
|
return extra_options_map.count(kXlaUseLinalgForDot) > 0;
|
||||||
|
}
|
||||||
|
|
||||||
static absl::string_view RemoveSuffix(absl::string_view str,
|
static absl::string_view RemoveSuffix(absl::string_view str,
|
||||||
absl::string_view suffix) {
|
absl::string_view suffix) {
|
||||||
CHECK_GE(str.size(), suffix.size());
|
CHECK_GE(str.size(), suffix.size());
|
||||||
|
|
|
@ -27,6 +27,7 @@ namespace options {
|
||||||
bool OptimizeForSizeRequested(const HloModuleConfig& config);
|
bool OptimizeForSizeRequested(const HloModuleConfig& config);
|
||||||
bool VectorizedReduceDisabled(const HloModuleConfig& config);
|
bool VectorizedReduceDisabled(const HloModuleConfig& config);
|
||||||
bool ForceEnableExperimentalLlvmIrGemm(const HloModuleConfig& config);
|
bool ForceEnableExperimentalLlvmIrGemm(const HloModuleConfig& config);
|
||||||
|
bool UseLinalgForDot(const HloModuleConfig& config);
|
||||||
absl::optional<int64> LlvmIrGemvTilingFactor(const HloModuleConfig& config);
|
absl::optional<int64> LlvmIrGemvTilingFactor(const HloModuleConfig& config);
|
||||||
absl::optional<std::tuple<int64, int64, int64>> LlvmIrGemmTileSize(
|
absl::optional<std::tuple<int64, int64, int64>> LlvmIrGemmTileSize(
|
||||||
const HloModuleConfig& config);
|
const HloModuleConfig& config);
|
||||||
|
|
|
@ -23,8 +23,17 @@ limitations under the License.
|
||||||
#include "llvm/IR/Instructions.h"
|
#include "llvm/IR/Instructions.h"
|
||||||
#include "llvm/IR/Module.h"
|
#include "llvm/IR/Module.h"
|
||||||
#include "llvm/IR/Value.h"
|
#include "llvm/IR/Value.h"
|
||||||
|
#include "mlir/Dialect/Linalg/EDSC/Intrinsics.h" // from @llvm-project
|
||||||
|
#include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h" // from @llvm-project
|
||||||
|
#include "mlir/EDSC/Builders.h" // from @llvm-project
|
||||||
|
#include "mlir/IR/Builders.h" // from @llvm-project
|
||||||
|
#include "mlir/IR/Function.h" // from @llvm-project
|
||||||
|
#include "mlir/IR/MLIRContext.h" // from @llvm-project
|
||||||
|
#include "mlir/IR/Value.h" // from @llvm-project
|
||||||
|
#include "tensorflow/compiler/xla/service/cpu/cpu_options.h"
|
||||||
#include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h"
|
#include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h"
|
||||||
#include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h"
|
#include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h"
|
||||||
|
#include "tensorflow/compiler/xla/service/cpu/mlir_emitter.h"
|
||||||
#include "tensorflow/compiler/xla/service/cpu/target_machine_features.h"
|
#include "tensorflow/compiler/xla/service/cpu/target_machine_features.h"
|
||||||
#include "tensorflow/compiler/xla/service/cpu/tiled_dot_emitter.h"
|
#include "tensorflow/compiler/xla/service/cpu/tiled_dot_emitter.h"
|
||||||
#include "tensorflow/compiler/xla/service/cpu/vector_support_library.h"
|
#include "tensorflow/compiler/xla/service/cpu/vector_support_library.h"
|
||||||
|
@ -89,6 +98,9 @@ enum class DotImplementationStrategy {
|
||||||
// and the output have to be row major.
|
// and the output have to be row major.
|
||||||
kTiledLlvmIrGemm,
|
kTiledLlvmIrGemm,
|
||||||
|
|
||||||
|
// The dot operation is lowered into linalg.matmul op and lowered to LLVM IR.
|
||||||
|
kLinalgMatmul,
|
||||||
|
|
||||||
// The dot operation is lowered into a call into an Eigen routine. No fusions
|
// The dot operation is lowered into a call into an Eigen routine. No fusions
|
||||||
// are supported today. The two inputs and the output have to be row major.
|
// are supported today. The two inputs and the output have to be row major.
|
||||||
// However, we do allow transposing either the LHS or the RHS as part of the
|
// However, we do allow transposing either the LHS or the RHS as part of the
|
||||||
|
@ -112,7 +124,7 @@ class DotOpEmitter {
|
||||||
const llvm_ir::IrArray& rhs_array,
|
const llvm_ir::IrArray& rhs_array,
|
||||||
const llvm_ir::IrArray* addend_array,
|
const llvm_ir::IrArray* addend_array,
|
||||||
llvm::Value* executable_run_options_value,
|
llvm::Value* executable_run_options_value,
|
||||||
llvm::IRBuilder<>* b,
|
llvm::IRBuilder<>* b, mlir::MLIRContext* mlir_context,
|
||||||
const HloModuleConfig& hlo_module_config,
|
const HloModuleConfig& hlo_module_config,
|
||||||
const TargetMachineFeatures& target_machine_features);
|
const TargetMachineFeatures& target_machine_features);
|
||||||
|
|
||||||
|
@ -163,6 +175,9 @@ class DotOpEmitter {
|
||||||
// Lowers the dot operation as a tiled Matrix*Matrix loop.
|
// Lowers the dot operation as a tiled Matrix*Matrix loop.
|
||||||
void EmitTiledLlvmIrGemm();
|
void EmitTiledLlvmIrGemm();
|
||||||
|
|
||||||
|
// Lowers the dot operation through MLIR's linalg.matmul.
|
||||||
|
Status EmitLinalgMatmul();
|
||||||
|
|
||||||
// Lowers the dot operation as a naive nested loop that computes the result
|
// Lowers the dot operation as a naive nested loop that computes the result
|
||||||
// one element at a time.
|
// one element at a time.
|
||||||
void EmitNaiveLlvmIrGemm();
|
void EmitNaiveLlvmIrGemm();
|
||||||
|
@ -194,19 +209,18 @@ class DotOpEmitter {
|
||||||
const llvm_ir::IrArray* addend_array_;
|
const llvm_ir::IrArray* addend_array_;
|
||||||
llvm::Value* executable_run_options_value_;
|
llvm::Value* executable_run_options_value_;
|
||||||
llvm::IRBuilder<>* b_;
|
llvm::IRBuilder<>* b_;
|
||||||
|
mlir::MLIRContext* mlir_context_;
|
||||||
const HloModuleConfig& hlo_module_config_;
|
const HloModuleConfig& hlo_module_config_;
|
||||||
const TargetMachineFeatures& target_machine_features_;
|
const TargetMachineFeatures& target_machine_features_;
|
||||||
};
|
};
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
DotOpEmitter::DotOpEmitter(DotInfo dot_info, string dot_hlo_name,
|
DotOpEmitter::DotOpEmitter(
|
||||||
const llvm_ir::IrArray& target_array,
|
DotInfo dot_info, string dot_hlo_name, const llvm_ir::IrArray& target_array,
|
||||||
const llvm_ir::IrArray& lhs_array,
|
const llvm_ir::IrArray& lhs_array, const llvm_ir::IrArray& rhs_array,
|
||||||
const llvm_ir::IrArray& rhs_array,
|
|
||||||
const llvm_ir::IrArray* addend_array,
|
const llvm_ir::IrArray* addend_array,
|
||||||
llvm::Value* executable_run_options_value,
|
llvm::Value* executable_run_options_value, llvm::IRBuilder<>* b,
|
||||||
llvm::IRBuilder<>* b,
|
mlir::MLIRContext* mlir_context, const HloModuleConfig& hlo_module_config,
|
||||||
const HloModuleConfig& hlo_module_config,
|
|
||||||
const TargetMachineFeatures& target_machine_features)
|
const TargetMachineFeatures& target_machine_features)
|
||||||
: dot_info_(std::move(dot_info)),
|
: dot_info_(std::move(dot_info)),
|
||||||
dot_hlo_name_(std::move(dot_hlo_name)),
|
dot_hlo_name_(std::move(dot_hlo_name)),
|
||||||
|
@ -216,9 +230,36 @@ DotOpEmitter::DotOpEmitter(DotInfo dot_info, string dot_hlo_name,
|
||||||
addend_array_(addend_array),
|
addend_array_(addend_array),
|
||||||
executable_run_options_value_(executable_run_options_value),
|
executable_run_options_value_(executable_run_options_value),
|
||||||
b_(b),
|
b_(b),
|
||||||
|
mlir_context_(mlir_context),
|
||||||
hlo_module_config_(hlo_module_config),
|
hlo_module_config_(hlo_module_config),
|
||||||
target_machine_features_(target_machine_features) {}
|
target_machine_features_(target_machine_features) {}
|
||||||
|
|
||||||
|
Status DotOpEmitter::EmitLinalgMatmul() {
|
||||||
|
Shape operand_shapes[] = {dot_info_.lhs_shape, dot_info_.rhs_shape};
|
||||||
|
llvm::Value* operand_ptrs[] = {lhs_array_.GetBasePointer(),
|
||||||
|
rhs_array_.GetBasePointer()};
|
||||||
|
llvm::Value* target_ptr = target_array_.GetBasePointer();
|
||||||
|
|
||||||
|
// Zero out the output buffer.
|
||||||
|
int64 size_bytes = ShapeUtil::ByteSizeOf(dot_info_.result_shape);
|
||||||
|
b_->CreateMemSet(target_ptr, b_->getInt8(0), /*Size=*/size_bytes,
|
||||||
|
/*Align=*/llvm::MaybeAlign(1));
|
||||||
|
|
||||||
|
std::string name =
|
||||||
|
absl::StrCat("linalgMatMul_", dot_info_.result_shape.ToString(true), "_",
|
||||||
|
dot_info_.lhs_shape.ToString(true), "_",
|
||||||
|
dot_info_.rhs_shape.ToString(true));
|
||||||
|
return EmitMlirFuncAndCall(
|
||||||
|
mlir_context_, b_, dot_info_.result_shape, operand_shapes, target_ptr,
|
||||||
|
operand_ptrs, name, [&](mlir::OpBuilder* builder, mlir::FuncOp function) {
|
||||||
|
mlir::edsc::ScopedContext scope(*builder, function.getLoc());
|
||||||
|
mlir::Value a = function.getArgument(0), b = function.getArgument(1),
|
||||||
|
c = function.getArgument(2);
|
||||||
|
mlir::edsc::intrinsics::linalg_matmul(b, c, a);
|
||||||
|
mlir::edsc::intrinsics::std_ret();
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
void DotOpEmitter::EmitTiledLlvmIrGemm() {
|
void DotOpEmitter::EmitTiledLlvmIrGemm() {
|
||||||
PrimitiveType primitive_type = dot_info_.result_shape.element_type();
|
PrimitiveType primitive_type = dot_info_.result_shape.element_type();
|
||||||
MatMultDims mat_mult_dims = GetMatMultDims();
|
MatMultDims mat_mult_dims = GetMatMultDims();
|
||||||
|
@ -418,6 +459,9 @@ Status DotOpEmitter::Emit() {
|
||||||
EmitTiledLlvmIrGemm();
|
EmitTiledLlvmIrGemm();
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
|
|
||||||
|
case DotImplementationStrategy::kLinalgMatmul:
|
||||||
|
return EmitLinalgMatmul();
|
||||||
|
|
||||||
case DotImplementationStrategy::kEigen:
|
case DotImplementationStrategy::kEigen:
|
||||||
return EmitCallToRuntime();
|
return EmitCallToRuntime();
|
||||||
}
|
}
|
||||||
|
@ -886,9 +930,12 @@ DotImplementationStrategy GetDotImplementationStrategy(
|
||||||
}
|
}
|
||||||
|
|
||||||
if (IsAlignedGemm(dot_info, target_machine_features)) {
|
if (IsAlignedGemm(dot_info, target_machine_features)) {
|
||||||
return CanEmitTiledLlvmIrGemm(config, dot_info, target_machine_features)
|
if (CanEmitTiledLlvmIrGemm(config, dot_info, target_machine_features)) {
|
||||||
? DotImplementationStrategy::kTiledLlvmIrGemm
|
return options::UseLinalgForDot(config)
|
||||||
: DotImplementationStrategy::kEigen;
|
? DotImplementationStrategy::kLinalgMatmul
|
||||||
|
: DotImplementationStrategy::kTiledLlvmIrGemm;
|
||||||
|
}
|
||||||
|
return DotImplementationStrategy::kEigen;
|
||||||
}
|
}
|
||||||
|
|
||||||
return DotImplementationStrategy::kNaiveLlvmIr;
|
return DotImplementationStrategy::kNaiveLlvmIr;
|
||||||
|
@ -899,15 +946,15 @@ Status EmitNonBatchDotOperation(
|
||||||
const llvm_ir::IrArray& lhs_array, const llvm_ir::IrArray& rhs_array,
|
const llvm_ir::IrArray& lhs_array, const llvm_ir::IrArray& rhs_array,
|
||||||
const llvm_ir::IrArray* addend_array,
|
const llvm_ir::IrArray* addend_array,
|
||||||
llvm::Value* executable_run_options_value, llvm::IRBuilder<>* b,
|
llvm::Value* executable_run_options_value, llvm::IRBuilder<>* b,
|
||||||
const HloModuleConfig& hlo_module_config,
|
mlir::MLIRContext* mlir_context, const HloModuleConfig& hlo_module_config,
|
||||||
const TargetMachineFeatures& target_machine_features) {
|
const TargetMachineFeatures& target_machine_features) {
|
||||||
PrimitiveType type = target_array.GetShape().element_type();
|
PrimitiveType type = target_array.GetShape().element_type();
|
||||||
TF_RET_CHECK(S32 == type || F16 == type || F32 == type || F64 == type ||
|
TF_RET_CHECK(S32 == type || F16 == type || F32 == type || F64 == type ||
|
||||||
C64 == type || C128 == type);
|
C64 == type || C128 == type);
|
||||||
DotOpEmitter dot_emitter(std::move(dot_info), std::move(hlo_name),
|
DotOpEmitter dot_emitter(std::move(dot_info), std::move(hlo_name),
|
||||||
target_array, lhs_array, rhs_array, addend_array,
|
target_array, lhs_array, rhs_array, addend_array,
|
||||||
executable_run_options_value, b, hlo_module_config,
|
executable_run_options_value, b, mlir_context,
|
||||||
target_machine_features);
|
hlo_module_config, target_machine_features);
|
||||||
return dot_emitter.Emit();
|
return dot_emitter.Emit();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -981,7 +1028,7 @@ Status EmitBatchDotOperation(
|
||||||
const HloInstruction& dot, const llvm_ir::IrArray& target_array,
|
const HloInstruction& dot, const llvm_ir::IrArray& target_array,
|
||||||
const llvm_ir::IrArray& lhs_array, const llvm_ir::IrArray& rhs_array,
|
const llvm_ir::IrArray& lhs_array, const llvm_ir::IrArray& rhs_array,
|
||||||
llvm::Value* executable_run_options_value, llvm::IRBuilder<>* b,
|
llvm::Value* executable_run_options_value, llvm::IRBuilder<>* b,
|
||||||
const HloModuleConfig& hlo_module_config,
|
mlir::MLIRContext* mlir_context, const HloModuleConfig& hlo_module_config,
|
||||||
const TargetMachineFeatures& target_machine_features) {
|
const TargetMachineFeatures& target_machine_features) {
|
||||||
TF_RETURN_IF_ERROR(ValidateDotDimensionNumbers(dot.dot_dimension_numbers()));
|
TF_RETURN_IF_ERROR(ValidateDotDimensionNumbers(dot.dot_dimension_numbers()));
|
||||||
|
|
||||||
|
@ -1039,7 +1086,7 @@ Status EmitBatchDotOperation(
|
||||||
// Emit the inner non-batch dot operation.
|
// Emit the inner non-batch dot operation.
|
||||||
return EmitNonBatchDotOperation(
|
return EmitNonBatchDotOperation(
|
||||||
dot_info, dot.name(), target_slice, lhs_slice, rhs_slice, nullptr,
|
dot_info, dot.name(), target_slice, lhs_slice, rhs_slice, nullptr,
|
||||||
executable_run_options_value, b, hlo_module_config,
|
executable_run_options_value, b, mlir_context, hlo_module_config,
|
||||||
target_machine_features);
|
target_machine_features);
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
@ -1089,7 +1136,7 @@ Status EmitDotOperation(const HloInstruction& dot,
|
||||||
const llvm_ir::IrArray& rhs_array,
|
const llvm_ir::IrArray& rhs_array,
|
||||||
const llvm_ir::IrArray* addend_array,
|
const llvm_ir::IrArray* addend_array,
|
||||||
llvm::Value* executable_run_options_value,
|
llvm::Value* executable_run_options_value,
|
||||||
llvm::IRBuilder<>* b,
|
llvm::IRBuilder<>* b, mlir::MLIRContext* mlir_context,
|
||||||
const HloModuleConfig& hlo_module_config,
|
const HloModuleConfig& hlo_module_config,
|
||||||
const TargetMachineFeatures& target_machine_features) {
|
const TargetMachineFeatures& target_machine_features) {
|
||||||
// This routine assumes that the dot operation is not in a parallelized
|
// This routine assumes that the dot operation is not in a parallelized
|
||||||
|
@ -1099,13 +1146,13 @@ Status EmitDotOperation(const HloInstruction& dot,
|
||||||
if (IsBatchDot(dot)) {
|
if (IsBatchDot(dot)) {
|
||||||
TF_RET_CHECK(addend_array == nullptr);
|
TF_RET_CHECK(addend_array == nullptr);
|
||||||
return EmitBatchDotOperation(dot, target_array, lhs_array, rhs_array,
|
return EmitBatchDotOperation(dot, target_array, lhs_array, rhs_array,
|
||||||
executable_run_options_value, b,
|
executable_run_options_value, b, mlir_context,
|
||||||
hlo_module_config, target_machine_features);
|
hlo_module_config, target_machine_features);
|
||||||
}
|
}
|
||||||
|
|
||||||
return EmitNonBatchDotOperation(DotInfo(dot), dot.name(), target_array,
|
return EmitNonBatchDotOperation(DotInfo(dot), dot.name(), target_array,
|
||||||
lhs_array, rhs_array, addend_array,
|
lhs_array, rhs_array, addend_array,
|
||||||
executable_run_options_value, b,
|
executable_run_options_value, b, mlir_context,
|
||||||
hlo_module_config, target_machine_features);
|
hlo_module_config, target_machine_features);
|
||||||
}
|
}
|
||||||
} // namespace cpu
|
} // namespace cpu
|
||||||
|
|
|
@ -18,6 +18,7 @@ limitations under the License.
|
||||||
|
|
||||||
#include "absl/strings/string_view.h"
|
#include "absl/strings/string_view.h"
|
||||||
#include "llvm/IR/IRBuilder.h"
|
#include "llvm/IR/IRBuilder.h"
|
||||||
|
#include "mlir/IR/MLIRContext.h" // from @llvm-project
|
||||||
#include "tensorflow/compiler/xla/service/cpu/cpu_options.h"
|
#include "tensorflow/compiler/xla/service/cpu/cpu_options.h"
|
||||||
#include "tensorflow/compiler/xla/service/cpu/target_machine_features.h"
|
#include "tensorflow/compiler/xla/service/cpu/target_machine_features.h"
|
||||||
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
||||||
|
@ -63,7 +64,7 @@ Status EmitDotOperation(const HloInstruction& dot,
|
||||||
const llvm_ir::IrArray& rhs_array,
|
const llvm_ir::IrArray& rhs_array,
|
||||||
const llvm_ir::IrArray* addend_array,
|
const llvm_ir::IrArray* addend_array,
|
||||||
llvm::Value* executable_run_options_value,
|
llvm::Value* executable_run_options_value,
|
||||||
llvm::IRBuilder<>* b,
|
llvm::IRBuilder<>* b, mlir::MLIRContext* mlir_context,
|
||||||
const HloModuleConfig& hlo_module_config,
|
const HloModuleConfig& hlo_module_config,
|
||||||
const TargetMachineFeatures& target_machine_features);
|
const TargetMachineFeatures& target_machine_features);
|
||||||
} // namespace cpu
|
} // namespace cpu
|
||||||
|
|
|
@ -89,8 +89,8 @@ using llvm_ir::SetToFirstInsertPoint;
|
||||||
namespace cpu {
|
namespace cpu {
|
||||||
|
|
||||||
IrEmitter::IrEmitter(
|
IrEmitter::IrEmitter(
|
||||||
const HloModule& hlo_module, const BufferAssignment& assignment,
|
mlir::MLIRContext* mlir_context, const HloModule& hlo_module,
|
||||||
llvm::Module* llvm_module,
|
const BufferAssignment& assignment, llvm::Module* llvm_module,
|
||||||
std::unordered_map<const HloInstruction*, int64> instruction_to_profile_idx,
|
std::unordered_map<const HloInstruction*, int64> instruction_to_profile_idx,
|
||||||
std::unordered_map<const HloComputation*, int64> computation_to_profile_idx,
|
std::unordered_map<const HloComputation*, int64> computation_to_profile_idx,
|
||||||
const TargetMachineFeatures* target_machine_features,
|
const TargetMachineFeatures* target_machine_features,
|
||||||
|
@ -99,6 +99,7 @@ IrEmitter::IrEmitter(
|
||||||
module_(llvm_module),
|
module_(llvm_module),
|
||||||
arch_type_(llvm::Triple(llvm_module->getTargetTriple()).getArch()),
|
arch_type_(llvm::Triple(llvm_module->getTargetTriple()).getArch()),
|
||||||
b_(llvm_module->getContext()),
|
b_(llvm_module->getContext()),
|
||||||
|
mlir_context_(mlir_context),
|
||||||
instruction_to_profile_idx_(std::move(instruction_to_profile_idx)),
|
instruction_to_profile_idx_(std::move(instruction_to_profile_idx)),
|
||||||
computation_to_profile_idx_(std::move(computation_to_profile_idx)),
|
computation_to_profile_idx_(std::move(computation_to_profile_idx)),
|
||||||
alias_analysis_(hlo_module, assignment, &llvm_module->getContext()),
|
alias_analysis_(hlo_module, assignment, &llvm_module->getContext()),
|
||||||
|
@ -898,7 +899,7 @@ Status IrEmitter::HandleDot(HloInstruction* dot) {
|
||||||
// Dot operation is complicated so we delegate to a helper class.
|
// Dot operation is complicated so we delegate to a helper class.
|
||||||
return EmitDotOperation(*dot, target_array, lhs_array, rhs_array,
|
return EmitDotOperation(*dot, target_array, lhs_array, rhs_array,
|
||||||
/*addend_array=*/nullptr,
|
/*addend_array=*/nullptr,
|
||||||
GetExecutableRunOptionsArgument(), &b_,
|
GetExecutableRunOptionsArgument(), &b_, mlir_context_,
|
||||||
hlo_module_config_, target_machine_features_);
|
hlo_module_config_, target_machine_features_);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -2305,9 +2306,9 @@ Status IrEmitter::HandleFusion(HloInstruction* fusion) {
|
||||||
llvm_ir::IrArray addend_array(
|
llvm_ir::IrArray addend_array(
|
||||||
GetIrArrayFor(fusion->operand(addend_param_number)));
|
GetIrArrayFor(fusion->operand(addend_param_number)));
|
||||||
|
|
||||||
TF_RETURN_IF_ERROR(
|
TF_RETURN_IF_ERROR(EmitDotOperation(
|
||||||
EmitDotOperation(*dot, target_array, lhs_array, rhs_array,
|
*dot, target_array, lhs_array, rhs_array, &addend_array,
|
||||||
&addend_array, GetExecutableRunOptionsArgument(), &b_,
|
GetExecutableRunOptionsArgument(), &b_, mlir_context_,
|
||||||
hlo_module_config_, target_machine_features_));
|
hlo_module_config_, target_machine_features_));
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
} else {
|
} else {
|
||||||
|
|
|
@ -17,6 +17,7 @@ limitations under the License.
|
||||||
#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_IR_EMITTER_H_
|
#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_IR_EMITTER_H_
|
||||||
|
|
||||||
#include <stddef.h>
|
#include <stddef.h>
|
||||||
|
|
||||||
#include <map>
|
#include <map>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <string>
|
#include <string>
|
||||||
|
@ -32,6 +33,7 @@ limitations under the License.
|
||||||
#include "llvm/IR/Module.h"
|
#include "llvm/IR/Module.h"
|
||||||
#include "llvm/IR/Value.h"
|
#include "llvm/IR/Value.h"
|
||||||
#include "llvm/Target/TargetMachine.h"
|
#include "llvm/Target/TargetMachine.h"
|
||||||
|
#include "mlir/IR/MLIRContext.h" // from @llvm-project
|
||||||
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
|
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
|
||||||
#include "tensorflow/compiler/xla/service/cpu/ir_function.h"
|
#include "tensorflow/compiler/xla/service/cpu/ir_function.h"
|
||||||
#include "tensorflow/compiler/xla/service/cpu/target_machine_features.h"
|
#include "tensorflow/compiler/xla/service/cpu/target_machine_features.h"
|
||||||
|
@ -69,14 +71,16 @@ class IrEmitter : public DfsHloVisitorWithDefault,
|
||||||
// hlo_module: the HLO module we are emitting IR for.
|
// hlo_module: the HLO module we are emitting IR for.
|
||||||
// assignment: a BufferAssignment from which we know which buffers are used by
|
// assignment: a BufferAssignment from which we know which buffers are used by
|
||||||
// the HLO nodes.
|
// the HLO nodes.
|
||||||
// llvm_module: the LLVM module to emit IR into.
|
// mlir_context: the MLIR context used for IR emission.
|
||||||
|
// llvm_module: the LLVM module to emit IR into. It's built using the LLVM
|
||||||
|
// context inside of mlir_context.
|
||||||
// instruction_to_profile_idx: the mapping from HLO instructions to their
|
// instruction_to_profile_idx: the mapping from HLO instructions to their
|
||||||
// index in the profiling array.
|
// index in the profiling array.
|
||||||
// computation_to_profile_idx: the mapping from HLO computations to their
|
// computation_to_profile_idx: the mapping from HLO computations to their
|
||||||
// index in the profiling array.
|
// index in the profiling array.
|
||||||
// emit_code_for_msan: whether emitted code should be compatible with msan.
|
// emit_code_for_msan: whether emitted code should be compatible with msan.
|
||||||
IrEmitter(const HloModule& hlo_module, const BufferAssignment& assignment,
|
IrEmitter(mlir::MLIRContext* mlir_context, const HloModule& hlo_module,
|
||||||
llvm::Module* llvm_module,
|
const BufferAssignment& assignment, llvm::Module* llvm_module,
|
||||||
std::unordered_map<const HloInstruction*, int64>
|
std::unordered_map<const HloInstruction*, int64>
|
||||||
instruction_to_profile_idx,
|
instruction_to_profile_idx,
|
||||||
std::unordered_map<const HloComputation*, int64>
|
std::unordered_map<const HloComputation*, int64>
|
||||||
|
@ -442,6 +446,7 @@ class IrEmitter : public DfsHloVisitorWithDefault,
|
||||||
// module's function list).
|
// module's function list).
|
||||||
std::unique_ptr<IrFunction> compute_function_;
|
std::unique_ptr<IrFunction> compute_function_;
|
||||||
llvm::IRBuilder<> b_;
|
llvm::IRBuilder<> b_;
|
||||||
|
mlir::MLIRContext* mlir_context_;
|
||||||
|
|
||||||
// The buffer allocation slice for the root of the computation being compiled.
|
// The buffer allocation slice for the root of the computation being compiled.
|
||||||
// Only relevant for thread local computations.
|
// Only relevant for thread local computations.
|
||||||
|
|
|
@ -0,0 +1,132 @@
|
||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/compiler/xla/service/cpu/mlir_emitter.h"
|
||||||
|
|
||||||
|
#include "llvm/Linker/Linker.h"
|
||||||
|
#include "llvm/Transforms/IPO/Internalize.h"
|
||||||
|
#include "mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h" // from @llvm-project
|
||||||
|
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" // from @llvm-project
|
||||||
|
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" // from @llvm-project
|
||||||
|
#include "mlir/Dialect/Linalg/Passes.h" // from @llvm-project
|
||||||
|
#include "mlir/IR/Module.h" // from @llvm-project
|
||||||
|
#include "mlir/Pass/Pass.h" // from @llvm-project
|
||||||
|
#include "mlir/Pass/PassManager.h" // from @llvm-project
|
||||||
|
#include "mlir/Target/LLVMIR.h" // from @llvm-project
|
||||||
|
#include "tensorflow/compiler/mlir/xla/hlo_utils.h"
|
||||||
|
|
||||||
|
namespace xla {
|
||||||
|
namespace cpu {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
// Lower an MLIR module to an LLVM module.
|
||||||
|
std::unique_ptr<llvm::Module> MakeLLVMModule(mlir::OwningModuleRef module) {
|
||||||
|
mlir::PassManager manager(module->getContext());
|
||||||
|
manager.addPass(mlir::createConvertLinalgToLoopsPass());
|
||||||
|
manager.addPass(mlir::createConvertLinalgToLLVMPass());
|
||||||
|
manager.addPass(mlir::createConvertVectorToLLVMPass());
|
||||||
|
manager.addPass(mlir::createLowerToLLVMPass());
|
||||||
|
CHECK(succeeded(manager.run(*module)));
|
||||||
|
return mlir::translateModuleToLLVMIR(*module);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get arguments to pass a memref to an mlir function.
|
||||||
|
void BuildViewForBuffer(llvm::SmallVectorImpl<llvm::Value *> *args,
|
||||||
|
llvm::IRBuilder<> *b, const Shape &opShape,
|
||||||
|
llvm::Value *op_val) {
|
||||||
|
llvm::Type *ty = op_val->getType();
|
||||||
|
while (auto aty = llvm::dyn_cast<llvm::ArrayType>(
|
||||||
|
llvm::cast<llvm::PointerType>(ty)->getElementType())) {
|
||||||
|
ty = aty->getElementType()->getPointerTo();
|
||||||
|
}
|
||||||
|
op_val = b->CreateBitCast(op_val, ty);
|
||||||
|
|
||||||
|
args->push_back(op_val); // Allocated pointer.
|
||||||
|
args->push_back(op_val); // Aligned pointer.
|
||||||
|
args->push_back(b->getInt64(0)); // Offset.
|
||||||
|
|
||||||
|
// Sizes.
|
||||||
|
for (int64 dim : opShape.dimensions()) {
|
||||||
|
args->push_back(b->getInt64(dim));
|
||||||
|
}
|
||||||
|
|
||||||
|
int64_t accumulated_stride = 1;
|
||||||
|
llvm::SmallVector<int64_t, 4> strides(opShape.rank(), 1);
|
||||||
|
for (int64 dim : LayoutUtil::MinorToMajor(opShape)) {
|
||||||
|
strides[dim] = accumulated_stride;
|
||||||
|
accumulated_stride *= opShape.dimensions(dim);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Strides.
|
||||||
|
for (int64 stride : strides) {
|
||||||
|
args->push_back(b->getInt64(stride));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
Status EmitMlirFuncAndCall(
|
||||||
|
mlir::MLIRContext *context, llvm::IRBuilder<> *b, const Shape &result_shape,
|
||||||
|
llvm::ArrayRef<Shape> operand_shapes, llvm::Value *result_ptr,
|
||||||
|
llvm::ArrayRef<llvm::Value *> operand_ptrs, llvm::StringRef func_name,
|
||||||
|
llvm::function_ref<void(mlir::OpBuilder *, mlir::FuncOp)> emitter) {
|
||||||
|
llvm::Module *llvm_module = b->GetInsertBlock()->getParent()->getParent();
|
||||||
|
mlir::Builder mlir_builder(context);
|
||||||
|
|
||||||
|
// Get memref types for the inputs and output.
|
||||||
|
TF_ASSIGN_OR_RETURN(mlir::Type ret_memref, ConvertTensorShapeToMemRefType(
|
||||||
|
result_shape, mlir_builder));
|
||||||
|
std::vector<mlir::Type> operand_types = {ret_memref};
|
||||||
|
for (int i = 0; i != operand_shapes.size(); ++i) {
|
||||||
|
TF_ASSIGN_OR_RETURN(
|
||||||
|
mlir::Type op_memref,
|
||||||
|
ConvertTensorShapeToMemRefType(operand_shapes[i], mlir_builder));
|
||||||
|
operand_types.push_back(op_memref);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create the function an call the emission callback.
|
||||||
|
mlir::Location loc = mlir::UnknownLoc::get(context);
|
||||||
|
auto function = mlir::FuncOp::create(
|
||||||
|
loc, func_name, mlir::FunctionType::get(operand_types, {}, context));
|
||||||
|
function.addEntryBlock();
|
||||||
|
mlir::OwningModuleRef mlir_module = mlir::ModuleOp::create(loc);
|
||||||
|
mlir_module->push_back(function);
|
||||||
|
mlir::OpBuilder op_builder(&function.getBody());
|
||||||
|
emitter(&op_builder, function);
|
||||||
|
|
||||||
|
// Now link it all into the main LLVM module.
|
||||||
|
auto mlir_llvm_module = MakeLLVMModule(std::move(mlir_module));
|
||||||
|
mlir_llvm_module->setDataLayout(llvm_module->getDataLayout());
|
||||||
|
llvm::Linker::linkModules(
|
||||||
|
*llvm_module, std::move(mlir_llvm_module), llvm::Linker::None,
|
||||||
|
[](llvm::Module &M, const llvm::StringSet<> &GVS) {
|
||||||
|
llvm::internalizeModule(M, [&GVS](const llvm::GlobalValue &GV) {
|
||||||
|
return !GV.hasName() || (GVS.count(GV.getName()) == 0);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
// And leave behind a call to the function generated by MLIR.
|
||||||
|
llvm::Function *func = llvm_module->getFunction(func_name);
|
||||||
|
llvm::SmallVector<llvm::Value *, 4> op_vals;
|
||||||
|
BuildViewForBuffer(&op_vals, b, result_shape, result_ptr);
|
||||||
|
for (int i = 0; i != operand_shapes.size(); ++i) {
|
||||||
|
BuildViewForBuffer(&op_vals, b, operand_shapes[i], operand_ptrs[i]);
|
||||||
|
}
|
||||||
|
b->CreateCall(func, op_vals);
|
||||||
|
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace cpu
|
||||||
|
} // namespace xla
|
|
@ -0,0 +1,43 @@
|
||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_MLIR_EMITTER_H_
|
||||||
|
#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_MLIR_EMITTER_H_
|
||||||
|
|
||||||
|
#include "llvm/IR/IRBuilder.h"
|
||||||
|
#include "llvm/IR/Value.h"
|
||||||
|
#include "mlir/IR/Builders.h" // from @llvm-project
|
||||||
|
#include "mlir/IR/Function.h" // from @llvm-project
|
||||||
|
#include "mlir/IR/MLIRContext.h" // from @llvm-project
|
||||||
|
#include "tensorflow/compiler/xla/shape.h"
|
||||||
|
#include "tensorflow/compiler/xla/status.h"
|
||||||
|
|
||||||
|
namespace xla {
|
||||||
|
namespace cpu {
|
||||||
|
|
||||||
|
// Create a new MLIR function with the name `func_name`, populate it with
|
||||||
|
// `emitter` and create a call, passing it the buffers defined by
|
||||||
|
// resultShape/resultPtr and operandShapes/operandPtrs. The function is added to
|
||||||
|
// the LLVM module at `b`s insertion point.
|
||||||
|
Status EmitMlirFuncAndCall(
|
||||||
|
mlir::MLIRContext *context, llvm::IRBuilder<> *b, const Shape &result_shape,
|
||||||
|
llvm::ArrayRef<Shape> operand_shapes, llvm::Value *result_ptr,
|
||||||
|
llvm::ArrayRef<llvm::Value *> operand_ptrs, llvm::StringRef func_name,
|
||||||
|
llvm::function_ref<void(mlir::OpBuilder *, mlir::FuncOp)> emitter);
|
||||||
|
|
||||||
|
} // namespace cpu
|
||||||
|
} // namespace xla
|
||||||
|
|
||||||
|
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_MLIR_EMITTER_H_
|
Loading…
Reference in New Issue