Outline generated LLVM IR matrix-vector dot kernels

This is a code size optimization for cases that dot matrix-vectors of the same
shape repeatedly, but is also a slight performance improvement (most likely due
to better icache behavior).

PiperOrigin-RevId: 177329302
This commit is contained in:
Sanjoy Das 2017-11-29 10:37:28 -08:00 committed by TensorFlower Gardener
parent 7921d01ec8
commit c572bc4fd7
3 changed files with 118 additions and 20 deletions

View File

@@ -522,8 +522,10 @@ bool DotOpEmitter::EmitLlvmIrDotIfProfitable() {
return false;
}
if (!primitive_util::IsFloatingPointType(dot_.shape().element_type()) &&
!primitive_util::IsIntegralType(dot_.shape().element_type())) {
PrimitiveType primitive_type = dot_.shape().element_type();
if (!primitive_util::IsFloatingPointType(primitive_type) &&
!primitive_util::IsIntegralType(primitive_type)) {
return false;
}
@@ -573,30 +575,50 @@ bool DotOpEmitter::EmitLlvmIrDotIfProfitable() {
int64 tiling_factor = GetGemvTilingFactor();
CHECK_GT(tiling_factor, 0);
llvm::Value* result_op = target_array_.GetBasePointer();
llvm::Value* lhs_op =
swap_operands ? rhs_array_.GetBasePointer() : lhs_array_.GetBasePointer();
llvm::Value* rhs_op =
swap_operands ? lhs_array_.GetBasePointer() : rhs_array_.GetBasePointer();
if (is_column_major_matrix_vector) {
VLOG(2) << "Emitting column major matrix-vector multiply with m = " << m
<< " and k = " << k;
ColumnMajorMatrixVectorProductEmitter emitter(
dot_.shape().element_type(), /*tile_rows=*/8,
/*tile_cols=*/tiling_factor, m, k,
swap_operands ? rhs_array_.GetBasePointer()
: lhs_array_.GetBasePointer(),
swap_operands ? lhs_array_.GetBasePointer()
: rhs_array_.GetBasePointer(),
target_array_.GetBasePointer(), ir_builder_);
emitter.Emit();
int64 tile_rows = 8;
int64 tile_cols = tiling_factor;
string kernel_name = tensorflow::strings::StrCat(
"col_major_gemv_", PrimitiveType_Name(primitive_type), "_", tile_rows,
"_", tile_cols, "_", m, "_", k);
KernelSupportLibrary::EmitAndCallOutlinedKernel(
ir_builder_, kernel_name, lhs_op, rhs_op, result_op,
[this, tile_rows, tile_cols, m, k, primitive_type](
llvm::Value* lhs_op, llvm::Value* rhs_op, llvm::Value* result_op) {
ColumnMajorMatrixVectorProductEmitter emitter(
primitive_type, tile_rows, tile_cols, m, k, lhs_op, rhs_op,
result_op, ir_builder_);
emitter.Emit();
});
} else {
VLOG(2) << "Emitting row major matrix-vector multiply with m = " << m
<< " and k = " << k;
RowMajorMatrixVectorProductEmitter emitter(
dot_.shape().element_type(), /*tile_rows=*/tiling_factor,
/*tile_cols=*/8, m, k,
swap_operands ? rhs_array_.GetBasePointer()
: lhs_array_.GetBasePointer(),
swap_operands ? lhs_array_.GetBasePointer()
: rhs_array_.GetBasePointer(),
target_array_.GetBasePointer(), ir_builder_);
emitter.Emit();
int64 tile_rows = tiling_factor;
int64 tile_cols = 8;
string kernel_name = tensorflow::strings::StrCat(
"row_major_gemv_", PrimitiveType_Name(primitive_type), "_", tile_rows,
"_", tile_cols, "_", m, "_", k);
KernelSupportLibrary::EmitAndCallOutlinedKernel(
ir_builder_, kernel_name, lhs_op, rhs_op, result_op,
[this, tile_rows, tile_cols, m, k, primitive_type](
llvm::Value* lhs_op, llvm::Value* rhs_op, llvm::Value* result_op) {
RowMajorMatrixVectorProductEmitter emitter(
primitive_type, tile_rows, tile_cols, m, k, lhs_op, rhs_op,
result_op, ir_builder_);
emitter.Emit();
});
}
return true;

View File

@@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
namespace xla {
void KernelSupportLibrary::For(
@@ -62,4 +63,47 @@ void KernelSupportLibrary::If(
false_block_generator();
llvm_ir::SetToLastInsertPoint(if_data.after_block, ir_builder_);
}
// Emits a call to an outlined kernel named `kernel_name`.  If the module does
// not yet contain a function with that name, the kernel body is generated
// first by invoking `kernel_body_generator` inside a fresh internal-linkage
// function; otherwise the existing function is re-used.  In that sense the
// llvm::Module acts as a cache of outlined kernels keyed by function name.
void KernelSupportLibrary::EmitAndCallOutlinedKernel(
llvm::IRBuilder<>* ir_builder, tensorflow::StringPiece kernel_name,
KernelSupportLibrary::ArgumentVector arguments,
const std::function<void(KernelSupportLibrary::ArgumentVector)>&
kernel_body_generator) {
llvm::Module* module = ir_builder->GetInsertBlock()->getModule();
// Look the kernel up by name; a hit means an identical kernel was already
// outlined earlier and can simply be called again.
llvm::Function* function =
module->getFunction(llvm_ir::AsStringRef(kernel_name));
if (!function) {
VLOG(2) << "Generating kernel for " << kernel_name;
// The outlined function takes exactly the types of the values we were
// asked to pass in, and returns void.
std::vector<llvm::Type*> arg_types;
std::transform(arguments.begin(), arguments.end(),
std::back_inserter(arg_types),
[](llvm::Value* arg) { return arg->getType(); });
auto* function_type = llvm::FunctionType::get(
ir_builder->getVoidTy(), arg_types, /*isVarArg=*/false);
// Internal linkage: the kernel is an implementation detail of this module
// and need not be visible outside of it.
function = llvm::Function::Create(
function_type, llvm::GlobalValue::InternalLinkage,
llvm_ir::AsStringRef(kernel_name), module);
// The guard restores the caller's insert point when this scope exits; we
// are about to redirect the builder into the new function's body.
llvm::IRBuilder<>::InsertPointGuard guard(*ir_builder);
auto* entry_bb =
llvm::BasicBlock::Create(ir_builder->getContext(), "entry", function);
auto* return_inst = llvm::ReturnInst::Create(ir_builder->getContext(),
/*retVal=*/nullptr, entry_bb);
// Set the insert point to before return_inst.
ir_builder->SetInsertPoint(return_inst);
// Hand the new function's formal parameters to the body generator, which
// emits the kernel body at the current (in-function) insert point.
std::vector<llvm::Value*> arg_values;
std::transform(function->arg_begin(), function->arg_end(),
std::back_inserter(arg_values), std::addressof<llvm::Value>);
kernel_body_generator(arg_values);
} else {
VLOG(3) << "Re-using kernel for " << kernel_name;
}
// Whether freshly generated or cached, emit the call at the caller's
// original insert point.
ir_builder->CreateCall(function, llvm_ir::AsArrayRef(arguments));
}
} // namespace xla

View File

@@ -118,6 +118,38 @@ class KernelSupportLibrary {
const std::function<void()>& true_block_generator,
const std::function<void()>& false_block_generator = []() {});
using ArgumentVector = tensorflow::gtl::ArraySlice<llvm::Value*>;
// Generates the following control flow structure:
//
// define @`kernel_name`(arg0, arg1, ... arg`arguments.size()`) {
// kernel_body_generator({arg0, arg1, ... arg`arguments.size()`});
// }
//
// ...
// call @`kernel_name`(arguments[0], arguments[1] ...)
// ...
//
// If a function called `kernel_name` is already present in the module then
// that function is re-used. In that sense we're using the llvm::Module as a
// cache of outlined kernels, keyed by function name.
static void EmitAndCallOutlinedKernel(
llvm::IRBuilder<>* ir_builder, tensorflow::StringPiece kernel_name,
ArgumentVector arguments,
const std::function<void(ArgumentVector)>& kernel_body_generator);
// Thin wrapper around the more general EmitAndCallOutlinedKernel above.
static void EmitAndCallOutlinedKernel(
    llvm::IRBuilder<>* ir_builder, tensorflow::StringPiece kernel_name,
    llvm::Value* arg0, llvm::Value* arg1, llvm::Value* arg2,
    const std::function<void(llvm::Value*, llvm::Value*, llvm::Value*)>&
        kernel_body_generator) {
  // Adapt the three-pointer body generator to the ArgumentVector form
  // expected by the general overload above.
  auto unpack_and_generate = [&](ArgumentVector args) {
    kernel_body_generator(args[0], args[1], args[2]);
  };
  EmitAndCallOutlinedKernel(ir_builder, kernel_name, {arg0, arg1, arg2},
                            unpack_and_generate);
}
private:
llvm::IRBuilder<>* ir_builder_;
bool prevent_unrolling_;