[XLA:CPU] Enable LLVM IR GEMM for small matrices.

We seem to be comfortably better than Eigen for matrices smaller than
32x128x128 on my local machine.

Also rename "block panel" to "tiled small gemm".  "block panel" was never the
right name for the kernel -- "block panel" is a way of tiling a large gemm into
smaller chunks, not a way of implementing the small gemm.

PiperOrigin-RevId: 209671700
This commit is contained in:
Sanjoy Das 2018-08-21 15:33:21 -07:00 committed by TensorFlower Gardener
parent b1b2cb38f2
commit daf992961e
2 changed files with 46 additions and 36 deletions

View File

@ -621,19 +621,19 @@ void RowMajorMatrixVectorProductEmitter::EmitInnerLoopEpilogue(
}
// This class implements a tiled matrix multiplication algorithm, intended for
// use as the innermost GEBP loop in a GEMM kernel (GEBP is described in "Goto,
// Kazushige, and Robert Van De Geijn. "High-performance implementation of the
// level-3 BLAS." ACM Transactions on Mathematical Software (TOMS) 35.1 (2008):
// 4).
// multiplying small matrices that don't need cache tiling.
//
// In the future this can be used as the innermost GEBP loop in a GEMM kernel as
// described in "Goto, Kazushige, and Robert A. Geijn. "Anatomy of
// high-performance matrix multiplication." ACM Transactions on Mathematical
// Software (TOMS) 34.3 (2008): 12.".
//
// This only supports canonical dot operations (i.e. where the lhs contraction
// dimension is 1 and the rhs contraction dimension is 0) over row major
// matrices.
class MatrixMatrixBlockPanelEmitter {
class TiledSmallGemmEmitter {
public:
// Describe the dimensions of the GEBP kernel. These will usually not be the
// dimensions of the GEMM itself, the GEMM will usually be broken up into GEBP
// kernels with smaller dimensions.
// Describe the dimensions of the kernel.
class Dimensions {
public:
explicit Dimensions(int64 m, int64 k, int64 n) : m_(m), k_(k), n_(n) {}
@ -652,9 +652,9 @@ class MatrixMatrixBlockPanelEmitter {
const int64 n_;
};
// Represents the configuration of the GEBP emitter. The LLVM IR emitted by
// the emitter, modulo the LLVM values holding the input and output buffers,
// must be a function of the instance of `Config` passed to it.
// Represents the configuration of the emitter. The LLVM IR emitted by the
// emitter, modulo the LLVM values holding the input and output buffers, must
// be a function of the instance of `Config` passed to it.
//
// `dims` holds the matrix multiplication dimensions.
//
@ -688,7 +688,7 @@ class MatrixMatrixBlockPanelEmitter {
string GetCacheKey() const {
return tensorflow::strings::StrCat(
"gebp_", PrimitiveType_Name(scalar_type()), "_", dims().ToString(),
"gemm_", PrimitiveType_Name(scalar_type()), "_", dims().ToString(),
"_", max_vectorization_width(), "_", min_vectorization_width(), "_",
tile_size_m(), "_", tile_size_k());
}
@ -712,11 +712,11 @@ class MatrixMatrixBlockPanelEmitter {
int64 tile_size_k_;
};
// Creates an instance of MatrixMatrixBlockPanelEmitter that matrix-multiplies
// Creates an instance of TiledSmallGemmEmitter that matrix-multiplies
// `lhs` with `rhs` and stores the result in `result`.
explicit MatrixMatrixBlockPanelEmitter(Config config, llvm::Value* lhs,
llvm::Value* rhs, llvm::Value* result,
llvm::IRBuilder<>* b)
explicit TiledSmallGemmEmitter(Config config, llvm::Value* lhs,
llvm::Value* rhs, llvm::Value* result,
llvm::IRBuilder<>* b)
: lhs_(lhs),
rhs_(rhs),
result_(result),
@ -780,9 +780,9 @@ class MatrixMatrixBlockPanelEmitter {
KernelSupportLibrary ksl_;
};
void MatrixMatrixBlockPanelEmitter::Emit() { HandleResiduesOnN(); }
void TiledSmallGemmEmitter::Emit() { HandleResiduesOnN(); }
void MatrixMatrixBlockPanelEmitter::HandleResiduesOnN() {
void TiledSmallGemmEmitter::HandleResiduesOnN() {
// We can only iterate the `n` dimension for an extent that is divisible by
// the vectorization width. So we emit an outer loop that first processes the
// largest extent in `n` that is divisible by max_vectorization_width, then
@ -799,7 +799,7 @@ void MatrixMatrixBlockPanelEmitter::HandleResiduesOnN() {
int64 n_end = dims().n() - (dims().n() % current_vectorization_width);
if (n_start != n_end) {
VectorSupportLibrary vsl(scalar_type(), current_vectorization_width, b_,
"gebp");
"gemm");
HandleResiduesOnK(&vsl, GetInt64(n_start), GetInt64(n_end));
n_start = n_end;
}
@ -813,7 +813,7 @@ void MatrixMatrixBlockPanelEmitter::HandleResiduesOnN() {
}
if (n_start != dims().n()) {
VectorSupportLibrary vsl(scalar_type(), 1, b_, "gebp");
VectorSupportLibrary vsl(scalar_type(), 1, b_, "gemm");
ksl_.ForReturnVoid("epi.n", n_start, dims().n(), 1, [&](llvm::Value* n_i) {
llvm::Value* n_i_next = b_->CreateAdd(n_i, b_->getInt64(1));
HandleResiduesOnK(&vsl, n_i, n_i_next);
@ -821,9 +821,9 @@ void MatrixMatrixBlockPanelEmitter::HandleResiduesOnN() {
}
}
void MatrixMatrixBlockPanelEmitter::HandleResiduesOnK(VectorSupportLibrary* vsl,
llvm::Value* n_start,
llvm::Value* n_end) {
void TiledSmallGemmEmitter::HandleResiduesOnK(VectorSupportLibrary* vsl,
llvm::Value* n_start,
llvm::Value* n_end) {
int64 k_start = 0;
int64 k_end = dims().k() - (dims().k() % tile_size_k());
if (k_end != k_start) {
@ -838,7 +838,7 @@ void MatrixMatrixBlockPanelEmitter::HandleResiduesOnK(VectorSupportLibrary* vsl,
}
}
void MatrixMatrixBlockPanelEmitter::HandleResiduesOnM(
void TiledSmallGemmEmitter::HandleResiduesOnM(
VectorSupportLibrary* vsl, int64 tile_size_k, llvm::Value* k_start,
llvm::Value* k_end, llvm::Value* n_start, llvm::Value* n_end) {
const int64 m_end = dims().m() - dims().m() % tile_size_m();
@ -921,7 +921,7 @@ void MatrixMatrixBlockPanelEmitter::HandleResiduesOnM(
// +-------------------+-------------------+-------------------+---------
// | a0*p0+b0*q0+c0*r0 | a0*p1+b0*q1+c0*r1 | a0*p2+b0*q2+c0*r2 | ...
// +-------------------+-------------------+-------------------+---------
void MatrixMatrixBlockPanelEmitter::EmitTiledGemm(
void TiledSmallGemmEmitter::EmitTiledGemm(
VectorSupportLibrary* vsl, int64 tile_size_k, llvm::Value* k_start,
llvm::Value* k_end, llvm::Value* n_start, llvm::Value* n_end,
int64 tile_size_m, llvm::Value* m_start, llvm::Value* m_end) {
@ -1001,12 +1001,22 @@ DotOpEmitter::DotOpEmitter(const HloInstruction& dot,
return dot_emitter.Emit();
}
bool DotOpEmitter::EmitExperimentalGebpDotIfEnabled(
bool DotOpEmitter::EmitSmallGemmIfProfitable(
const DotOpEmitter::MatMultDims& mat_mult_dims) {
if (!EnableExperimentalLlvmIrGemm() || ShouldUseMultiThreadedEigen()) {
if (ShouldUseMultiThreadedEigen()) {
return false;
}
if (!EnableExperimentalLlvmIrGemm()) {
// TODO(sanjoy): We should make these numbers micro-arch specific.
bool small_gemm = mat_mult_dims.k <= 128 &&
((mat_mult_dims.m <= 32 && mat_mult_dims.n <= 128) ||
(mat_mult_dims.m <= 128 && mat_mult_dims.n <= 32));
if (!small_gemm) {
return false;
}
}
if (mat_mult_dims.lhs_non_canonical || mat_mult_dims.rhs_non_canonical) {
return false;
}
@ -1054,15 +1064,15 @@ bool DotOpEmitter::EmitExperimentalGebpDotIfEnabled(
std::tie(tile_size_m, tile_size_k, tile_size_n_in_vector_width) =
GetGemmTileSize();
MatrixMatrixBlockPanelEmitter::Config config(
TiledSmallGemmEmitter::Config config(
/*scalar_type=*/primitive_type,
MatrixMatrixBlockPanelEmitter::Dimensions{/*m=*/m, /*k=*/k, /*n=*/n},
TiledSmallGemmEmitter::Dimensions{/*m=*/m, /*k=*/k, /*n=*/n},
/*max_vectorization_width=*/max_target_vector_width,
/*max_vector_count=*/tile_size_n_in_vector_width,
/*min_vectorization_width=*/std::min<int64>(4, max_target_vector_width),
/*tile_size_m=*/tile_size_m, /*tile_size_k=*/tile_size_k);
VLOG(2) << "Emitting GEBP kernel in LLVM IR with config "
VLOG(2) << "Emitting GEMM kernel in LLVM IR with config "
<< config.GetCacheKey();
const bool enable_fast_math =
@ -1075,10 +1085,10 @@ bool DotOpEmitter::EmitExperimentalGebpDotIfEnabled(
/*optimize_for_size=*/optimize_for_size, b_, config.GetCacheKey(), lhs,
rhs, target,
[this, config](llvm::Value* lhs, llvm::Value* rhs, llvm::Value* target) {
MatrixMatrixBlockPanelEmitter gebp_emitter(config, /*lhs=*/lhs,
/*rhs=*/rhs,
/*result=*/target, b_);
gebp_emitter.Emit();
TiledSmallGemmEmitter small_gemm_emitter(config, /*lhs=*/lhs,
/*rhs=*/rhs,
/*result=*/target, b_);
small_gemm_emitter.Emit();
});
return true;
@ -1136,7 +1146,7 @@ bool DotOpEmitter::EmitLlvmIrDotIfProfitable() {
}
if (!is_column_major_matrix_vector && !is_row_major_matrix_vector) {
return EmitExperimentalGebpDotIfEnabled(mat_mult_dims);
return EmitSmallGemmIfProfitable(mat_mult_dims);
}
int64 tiling_factor = GetGemvTilingFactor();

View File

@ -121,7 +121,7 @@ class DotOpEmitter {
// of rank 2 as well).
MatMultDims GetMatMultDims() const;
bool EmitExperimentalGebpDotIfEnabled(const MatMultDims& mat_mult_dims);
bool EmitSmallGemmIfProfitable(const MatMultDims& mat_mult_dims);
// When doing a tiled GEMV in LLVM IR, a "tile" consists of this many vector
// registers.