[XLA:CPU] Make one of the tile dimensions in the LLVM IR GEMV tunable.
The tiling dimension corresponding to the number of vector registers in the tile can be changed easily. Expose this value as a backend specific flag so that we can experiment with it to find a good default value. This CL also fixes a bug exposed by a variable tiling factor in the row major GEMV implementation. This wasn't caught before because having tile_rows == tile_cols hides the bug. PiperOrigin-RevId: 175258553
This commit is contained in:
parent
3c41cb6bff
commit
23dc70389b
tensorflow/compiler/xla/service/cpu
@ -280,6 +280,7 @@ cc_library(
|
||||
srcs = ["dot_op_emitter.cc"],
|
||||
hdrs = ["dot_op_emitter.h"],
|
||||
deps = [
|
||||
":cpu_options",
|
||||
":cpu_runtime",
|
||||
":ir_emission_utils",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
@ -719,6 +720,7 @@ cc_library(
|
||||
hdrs = ["cpu_options.h"],
|
||||
deps = [
|
||||
"//tensorflow/compiler/xla/service:hlo_module_config",
|
||||
"//tensorflow/core:lib",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -15,11 +15,14 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/compiler/xla/service/cpu/cpu_options.h"
|
||||
|
||||
#include "tensorflow/core/lib/strings/numbers.h"
|
||||
|
||||
namespace {
|
||||
|
||||
const char* const kXlaParallelCpuOption = "xla_cpu_parallel";
|
||||
const char* const kXlaOptimizeForSizeCpuOption = "xla_cpu_optimize_for_size";
|
||||
const char* const kXlaDisableVectorizedReduce = "xla_disable_vectorized_reduce";
|
||||
const char* const kLlvmIrDotTilingFactor = "xla_llvm_dot_tiling_factor";
|
||||
|
||||
} // namespace
|
||||
|
||||
@ -45,6 +48,19 @@ bool VectorizedReduceDisabled(const HloModuleConfig& config) {
|
||||
return extra_options_map.count(kXlaOptimizeForSizeCpuOption) > 0;
|
||||
}
|
||||
|
||||
tensorflow::gtl::optional<int64> LlvmIrGemvTilingFactor(
|
||||
const HloModuleConfig& config) {
|
||||
const auto& extra_options_map =
|
||||
config.debug_options().xla_backend_extra_options();
|
||||
auto it = extra_options_map.find(kLlvmIrDotTilingFactor);
|
||||
int64 tiling_factor;
|
||||
if (it != extra_options_map.end() &&
|
||||
tensorflow::strings::safe_strto64(it->second, &tiling_factor)) {
|
||||
return tiling_factor;
|
||||
}
|
||||
return tensorflow::gtl::nullopt;
|
||||
}
|
||||
|
||||
} // namespace options
|
||||
} // namespace cpu
|
||||
} // namespace xla
|
||||
|
@ -27,6 +27,8 @@ namespace options {
|
||||
bool CpuParallelBackendRequested(const HloModuleConfig& config);
|
||||
bool OptimizeForSizeRequested(const HloModuleConfig& config);
|
||||
bool VectorizedReduceDisabled(const HloModuleConfig& config);
|
||||
tensorflow::gtl::optional<int64> LlvmIrGemvTilingFactor(
|
||||
const HloModuleConfig& config);
|
||||
|
||||
} // namespace options
|
||||
} // namespace cpu
|
||||
|
@ -366,7 +366,7 @@ class RowMajorMatrixVectorProductEmitter {
|
||||
result_(result),
|
||||
ir_builder_(ir_builder),
|
||||
ksl_(ir_builder_),
|
||||
vsl_(scalar_type_, /*vector_size=*/tile_rows_, ir_builder_, "") {
|
||||
vsl_(scalar_type_, /*vector_size=*/tile_cols_, ir_builder_, "") {
|
||||
CHECK(tile_cols_ > 0 && IsPowerOfTwo(static_cast<uint64>(tile_cols_)));
|
||||
}
|
||||
|
||||
@ -573,11 +573,15 @@ bool DotOpEmitter::EmitLlvmIrDotIfProfitable() {
|
||||
return false;
|
||||
}
|
||||
|
||||
int64 tiling_factor = GetGemvTilingFactor();
|
||||
CHECK_GT(tiling_factor, 0);
|
||||
|
||||
if (is_column_major_matrix_vector) {
|
||||
VLOG(2) << "Emitting column major matrix-vector multiply with m = " << m
|
||||
<< " and k = " << k;
|
||||
ColumnMajorMatrixVectorProductEmitter emitter(
|
||||
dot_.shape().element_type(), 8, 8, m, k,
|
||||
dot_.shape().element_type(), /*tile_rows=*/8,
|
||||
/*tile_cols=*/tiling_factor, m, k,
|
||||
swap_operands ? rhs_array_.GetBasePointer()
|
||||
: lhs_array_.GetBasePointer(),
|
||||
swap_operands ? lhs_array_.GetBasePointer()
|
||||
@ -588,7 +592,8 @@ bool DotOpEmitter::EmitLlvmIrDotIfProfitable() {
|
||||
VLOG(2) << "Emitting row major matrix-vector multiply with m = " << m
|
||||
<< " and k = " << k;
|
||||
RowMajorMatrixVectorProductEmitter emitter(
|
||||
dot_.shape().element_type(), 8, 8, m, k,
|
||||
dot_.shape().element_type(), /*tile_rows=*/tiling_factor,
|
||||
/*tile_cols=*/8, m, k,
|
||||
swap_operands ? rhs_array_.GetBasePointer()
|
||||
: lhs_array_.GetBasePointer(),
|
||||
swap_operands ? lhs_array_.GetBasePointer()
|
||||
|
@ -17,6 +17,7 @@ limitations under the License.
|
||||
#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_DOT_OP_EMITTER_H_
|
||||
|
||||
#include "llvm/IR/IRBuilder.h"
|
||||
#include "tensorflow/compiler/xla/service/cpu/cpu_options.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
|
||||
#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
|
||||
@ -105,6 +106,14 @@ class DotOpEmitter {
|
||||
// of rank 2 as well).
|
||||
MatMultDims GetMatMultDims() const;
|
||||
|
||||
// When doing a tiled GEMV in LLVM IR, a "tile" consists of this many vector
|
||||
// registers.
|
||||
int64 GetGemvTilingFactor() const {
|
||||
const int64 kDefaultTilingFactor = 8;
|
||||
return options::LlvmIrGemvTilingFactor(hlo_module_config_)
|
||||
.value_or(kDefaultTilingFactor);
|
||||
}
|
||||
|
||||
const HloInstruction& dot_;
|
||||
const bool transpose_lhs_;
|
||||
const bool transpose_rhs_;
|
||||
|
Loading…
Reference in New Issue
Block a user