Outline generated LLVM IR matrix-vector dot kernels
This is a code size optimization for cases that execute matrix-vector dot products of the same shape repeatedly, but it is also a slight performance improvement (most likely due to better icache behavior).

PiperOrigin-RevId: 177329302
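In outline, each generated gemv loop nest is moved into a named function with internal linkage, and the llvm::Module doubles as a cache keyed by function name, so repeated dots of the same shape share one kernel body. A minimal sketch of that get-or-create pattern against the LLVM C++ API (GetOrCreateKernel is an illustrative name, not part of this change; the helper actually added below is KernelSupportLibrary::EmitAndCallOutlinedKernel):

#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Module.h"

// Illustrative only: return the kernel named `name`, creating it on first
// use. Looking the name up in the module first is what de-duplicates the
// generated code: two requests for the same name share a single body.
llvm::Function* GetOrCreateKernel(llvm::Module* module, llvm::StringRef name,
                                  llvm::ArrayRef<llvm::Type*> arg_types,
                                  llvm::IRBuilder<>* b) {
  if (llvm::Function* existing = module->getFunction(name)) {
    return existing;  // Cache hit: skip re-emitting the loop nest.
  }
  auto* fn_type =
      llvm::FunctionType::get(b->getVoidTy(), arg_types, /*isVarArg=*/false);
  auto* fn = llvm::Function::Create(
      fn_type, llvm::GlobalValue::InternalLinkage, name, module);
  // ... emit the gemv loop nest into `fn` here ...
  return fn;
}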
parent 7921d01ec8
commit c572bc4fd7
tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc

@@ -522,8 +522,10 @@ bool DotOpEmitter::EmitLlvmIrDotIfProfitable() {
     return false;
   }
 
-  if (!primitive_util::IsFloatingPointType(dot_.shape().element_type()) &&
-      !primitive_util::IsIntegralType(dot_.shape().element_type())) {
+  PrimitiveType primitive_type = dot_.shape().element_type();
+
+  if (!primitive_util::IsFloatingPointType(primitive_type) &&
+      !primitive_util::IsIntegralType(primitive_type)) {
     return false;
   }
 
@@ -573,30 +575,50 @@ bool DotOpEmitter::EmitLlvmIrDotIfProfitable() {
   int64 tiling_factor = GetGemvTilingFactor();
   CHECK_GT(tiling_factor, 0);
 
+  llvm::Value* result_op = target_array_.GetBasePointer();
+  llvm::Value* lhs_op =
+      swap_operands ? rhs_array_.GetBasePointer() : lhs_array_.GetBasePointer();
+  llvm::Value* rhs_op =
+      swap_operands ? lhs_array_.GetBasePointer() : rhs_array_.GetBasePointer();
+
   if (is_column_major_matrix_vector) {
     VLOG(2) << "Emitting column major matrix-vector multiply with m = " << m
             << " and k = " << k;
-    ColumnMajorMatrixVectorProductEmitter emitter(
-        dot_.shape().element_type(), /*tile_rows=*/8,
-        /*tile_cols=*/tiling_factor, m, k,
-        swap_operands ? rhs_array_.GetBasePointer()
-                      : lhs_array_.GetBasePointer(),
-        swap_operands ? lhs_array_.GetBasePointer()
-                      : rhs_array_.GetBasePointer(),
-        target_array_.GetBasePointer(), ir_builder_);
-    emitter.Emit();
+    int64 tile_rows = 8;
+    int64 tile_cols = tiling_factor;
+
+    string kernel_name = tensorflow::strings::StrCat(
+        "col_major_gemv_", PrimitiveType_Name(primitive_type), "_", tile_rows,
+        "_", tile_cols, "_", m, "_", k);
+
+    KernelSupportLibrary::EmitAndCallOutlinedKernel(
+        ir_builder_, kernel_name, lhs_op, rhs_op, result_op,
+        [this, tile_rows, tile_cols, m, k, primitive_type](
+            llvm::Value* lhs_op, llvm::Value* rhs_op, llvm::Value* result_op) {
+          ColumnMajorMatrixVectorProductEmitter emitter(
+              primitive_type, tile_rows, tile_cols, m, k, lhs_op, rhs_op,
+              result_op, ir_builder_);
+          emitter.Emit();
+        });
   } else {
     VLOG(2) << "Emitting row major matrix-vector multiply with m = " << m
             << " and k = " << k;
-    RowMajorMatrixVectorProductEmitter emitter(
-        dot_.shape().element_type(), /*tile_rows=*/tiling_factor,
-        /*tile_cols=*/8, m, k,
-        swap_operands ? rhs_array_.GetBasePointer()
-                      : lhs_array_.GetBasePointer(),
-        swap_operands ? lhs_array_.GetBasePointer()
-                      : rhs_array_.GetBasePointer(),
-        target_array_.GetBasePointer(), ir_builder_);
-    emitter.Emit();
+    int64 tile_rows = tiling_factor;
+    int64 tile_cols = 8;
+
+    string kernel_name = tensorflow::strings::StrCat(
+        "row_major_gemv_", PrimitiveType_Name(primitive_type), "_", tile_rows,
+        "_", tile_cols, "_", m, "_", k);
+
+    KernelSupportLibrary::EmitAndCallOutlinedKernel(
+        ir_builder_, kernel_name, lhs_op, rhs_op, result_op,
+        [this, tile_rows, tile_cols, m, k, primitive_type](
+            llvm::Value* lhs_op, llvm::Value* rhs_op, llvm::Value* result_op) {
+          RowMajorMatrixVectorProductEmitter emitter(
+              primitive_type, tile_rows, tile_cols, m, k, lhs_op, rhs_op,
+              result_op, ir_builder_);
+          emitter.Emit();
+        });
   }
 
   return true;
tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.cc

@@ -16,6 +16,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h"
 
 #include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h"
+#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
 
 namespace xla {
 void KernelSupportLibrary::For(
@@ -62,4 +63,47 @@ void KernelSupportLibrary::If(
     false_block_generator();
   llvm_ir::SetToLastInsertPoint(if_data.after_block, ir_builder_);
 }
+
+void KernelSupportLibrary::EmitAndCallOutlinedKernel(
+    llvm::IRBuilder<>* ir_builder, tensorflow::StringPiece kernel_name,
+    KernelSupportLibrary::ArgumentVector arguments,
+    const std::function<void(KernelSupportLibrary::ArgumentVector)>&
+        kernel_body_generator) {
+  llvm::Module* module = ir_builder->GetInsertBlock()->getModule();
+  llvm::Function* function =
+      module->getFunction(llvm_ir::AsStringRef(kernel_name));
+  if (!function) {
+    VLOG(2) << "Generating kernel for " << kernel_name;
+    std::vector<llvm::Type*> arg_types;
+    std::transform(arguments.begin(), arguments.end(),
+                   std::back_inserter(arg_types),
+                   [](llvm::Value* arg) { return arg->getType(); });
+
+    auto* function_type = llvm::FunctionType::get(
+        ir_builder->getVoidTy(), arg_types, /*isVarArg=*/false);
+
+    function = llvm::Function::Create(
+        function_type, llvm::GlobalValue::InternalLinkage,
+        llvm_ir::AsStringRef(kernel_name), module);
+
+    llvm::IRBuilder<>::InsertPointGuard guard(*ir_builder);
+
+    auto* entry_bb =
+        llvm::BasicBlock::Create(ir_builder->getContext(), "entry", function);
+    auto* return_inst = llvm::ReturnInst::Create(ir_builder->getContext(),
+                                                 /*retVal=*/nullptr, entry_bb);
+    // Set the insert point to before return_inst.
+    ir_builder->SetInsertPoint(return_inst);
+
+    std::vector<llvm::Value*> arg_values;
+    std::transform(function->arg_begin(), function->arg_end(),
+                   std::back_inserter(arg_values),
+                   std::addressof<llvm::Value>);
+    kernel_body_generator(arg_values);
+  } else {
+    VLOG(3) << "Re-using kernel for " << kernel_name;
+  }
+
+  ir_builder->CreateCall(function, llvm_ir::AsArrayRef(arguments));
+}
+
 }  // namespace xla
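For illustration, a hypothetical pair of call sites (the builder b and the pointer-typed values a0, a1, a2 are assumed to be in scope) shows the caching behavior implemented above: the first call defines and calls the kernel, the second finds the name in the module and only emits a call.

// Hypothetical usage; "my_kernel", b, a0, a1 and a2 are illustrative names.
KernelSupportLibrary::EmitAndCallOutlinedKernel(
    b, "my_kernel", a0, a1, a2,
    [&](llvm::Value* x, llvm::Value* y, llvm::Value* z) {
      // First call: no "my_kernel" in the module yet, so this generator
      // runs once to populate the newly created internal function.
    });

KernelSupportLibrary::EmitAndCallOutlinedKernel(
    b, "my_kernel", a0, a1, a2,
    [&](llvm::Value* x, llvm::Value* y, llvm::Value* z) {
      // Second call: "my_kernel" already exists, so this generator is not
      // invoked; only a call instruction is emitted.
    });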
tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h

@@ -118,6 +118,38 @@ class KernelSupportLibrary {
       const std::function<void()>& true_block_generator,
       const std::function<void()>& false_block_generator = []() {});
 
+  using ArgumentVector = tensorflow::gtl::ArraySlice<llvm::Value*>;
+
+  // Generates the following control flow structure:
+  //
+  //   define @`kernel_name`(arg0, arg1, ... arg`arguments.size()`) {
+  //     kernel_body_generator({arg0, arg1, ... arg`arguments.size()`});
+  //   }
+  //
+  //   ...
+  //   call @`kernel_name`(arguments[0], arguments[1] ...)
+  //   ...
+  //
+  // If a function called `kernel_name` is already present in the module then
+  // that function is re-used.  In that sense we're using the llvm::Module as
+  // a cache of outlined kernels, keyed by function name.
+  static void EmitAndCallOutlinedKernel(
+      llvm::IRBuilder<>* ir_builder, tensorflow::StringPiece kernel_name,
+      ArgumentVector arguments,
+      const std::function<void(ArgumentVector)>& kernel_body_generator);
+
+  // Thin wrapper around the more general EmitAndCallOutlinedKernel above.
+  static void EmitAndCallOutlinedKernel(
+      llvm::IRBuilder<>* ir_builder, tensorflow::StringPiece kernel_name,
+      llvm::Value* arg0, llvm::Value* arg1, llvm::Value* arg2,
+      const std::function<void(llvm::Value*, llvm::Value*, llvm::Value*)>&
+          kernel_body_generator) {
+    EmitAndCallOutlinedKernel(
+        ir_builder, kernel_name, {arg0, arg1, arg2},
+        [&](ArgumentVector args) {
+          kernel_body_generator(args[0], args[1], args[2]);
+        });
+  }
+
  private:
   llvm::IRBuilder<>* ir_builder_;
   bool prevent_unrolling_;
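Note on the cache key: the dot emitter encodes the primitive type, tile sizes, m, and k into the kernel name, so only dots with an identical configuration share an outlined body. As a sketch (the literal values F32, 8, 4, 128 and 256 are illustrative), mirroring the StrCat call in the dot_op_emitter change above:

// Expected to produce "col_major_gemv_F32_8_4_128_256".
string kernel_name = tensorflow::strings::StrCat(
    "col_major_gemv_", PrimitiveType_Name(F32), "_", 8, "_", 4, "_", 128,
    "_", 256);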