[XLA:CPU] Make the Eigen matmul routines flexible around array alignment
This code will be used in a later CL where we will implement batchdot in XLA:CPU by calling individual dot operations in a loop.

PiperOrigin-RevId: 226091175
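The pattern the diff introduces is small enough to show on its own: probe the 16-byte alignment of all three buffers at runtime, then instantiate the Eigen kernel with either Eigen::Aligned16 or Eigen::Unaligned so that Eigen only assumes alignment that has actually been verified. The sketch below is a standalone illustration of that dispatch, not the XLA runtime code itself: MatMulKernel is a simplified stand-in for the MatMul template in the diff (no transpose flags, no ExecutableRunOptions or thread-pool device), and it assumes only Eigen's Tensor module with column-major layouts.

// Standalone sketch of the alignment dispatch (simplified; not the XLA code).
#include <cstdint>

#include "unsupported/Eigen/CXX11/Tensor"

namespace {

bool Is16BytesAligned(const void* ptr) {
  return reinterpret_cast<std::uintptr_t>(ptr) % 16 == 0;
}

// Kernel templated on the alignment Eigen may assume: C = A * B, with A of
// shape [m, k], B of shape [k, n], all column-major.
template <typename T, Eigen::AlignmentType Alignment>
void MatMulKernel(T* out, T* lhs, T* rhs, int64_t m, int64_t n, int64_t k) {
  const Eigen::TensorMap<Eigen::Tensor<const T, 2>, Alignment> A(lhs, m, k);
  const Eigen::TensorMap<Eigen::Tensor<const T, 2>, Alignment> B(rhs, k, n);
  Eigen::TensorMap<Eigen::Tensor<T, 2>, Alignment> C(out, m, n);

  // Contract the second dimension of A with the first dimension of B.
  const Eigen::array<Eigen::IndexPair<int>, 1> dims = {
      Eigen::IndexPair<int>(1, 0)};
  C = A.contract(B, dims);
}

// Only promise Eigen 16-byte alignment when every buffer actually has it.
template <typename T>
void MatMulDispatch(T* out, T* lhs, T* rhs, int64_t m, int64_t n, int64_t k) {
  const bool all_aligned =
      Is16BytesAligned(out) && Is16BytesAligned(lhs) && Is16BytesAligned(rhs);
  if (all_aligned) {
    MatMulKernel<T, Eigen::Aligned16>(out, lhs, rhs, m, n, k);
  } else {
    MatMulKernel<T, Eigen::Unaligned>(out, lhs, rhs, m, n, k);
  }
}

}  // namespace

In the diff itself the dispatch additionally routes m == 1 or n == 1 to xla::EigenMatVec, and the exported entry points (__xla_cpu_runtime_EigenMatMulF32 and friends) simply forward to it, so callers no longer have to guarantee 16-byte-aligned buffers.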
parent b2e8e62466
commit 51bac90e00

tensorflow/compiler/xla/service/cpu/BUILD
@@ -767,6 +767,8 @@ cc_library(
         ":target_machine_features",
         "//tensorflow/compiler/xla:util",
         "//tensorflow/compiler/xla/service:computation_layout",
+        "//tensorflow/compiler/xla/service:hlo",
+        "//tensorflow/compiler/xla/service:hlo_casting_utils",
         "//tensorflow/compiler/xla/service:layout_assignment",
         "//tensorflow/core:lib",
         "@com_google_absl//absl/container:flat_hash_map",

tensorflow/compiler/xla/service/cpu/runtime_matmul.cc
@@ -32,7 +32,11 @@ using tensorflow::int64;
 
 namespace {
 
-template <typename T>
+bool Is16BytesAligned(void* ptr) {
+  return reinterpret_cast<uintptr_t>(ptr) % 16 == 0;
+}
+
+template <typename T, Eigen::AlignmentType Alignment>
 void MatMul(const void* run_options_ptr, T* out, T* lhs, T* rhs, int64 m,
             int64 n, int64 k, int32 transpose_lhs, int32 transpose_rhs) {
   const xla::ExecutableRunOptions* run_options =
@@ -50,11 +54,11 @@ void MatMul(const void* run_options_ptr, T* out, T* lhs, T* rhs, int64 m,
     std::swap(rhs_rows, rhs_cols);
   }
 
-  const Eigen::TensorMap<Eigen::Tensor<const T, 2>, Eigen::Aligned> A(
-      lhs, lhs_rows, lhs_cols);
-  const Eigen::TensorMap<Eigen::Tensor<const T, 2>, Eigen::Aligned> B(
-      rhs, rhs_rows, rhs_cols);
-  Eigen::TensorMap<Eigen::Tensor<T, 2>, Eigen::Aligned> C(out, m, n);
+  const Eigen::TensorMap<Eigen::Tensor<const T, 2>, Alignment> A(lhs, lhs_rows,
+                                                                 lhs_cols);
+  const Eigen::TensorMap<Eigen::Tensor<const T, 2>, Alignment> B(rhs, rhs_rows,
+                                                                 rhs_cols);
+  Eigen::TensorMap<Eigen::Tensor<T, 2>, Alignment> C(out, m, n);
 
   typedef typename Eigen::Tensor<T, 2>::DimensionPair DimPair;
   int lhs_contract_dim = transpose_lhs ? 0 : 1;
@@ -69,14 +73,24 @@ void MatMul(const void* run_options_ptr, T* out, T* lhs, T* rhs, int64 m,
 }
 
 template <typename T>
-void MatMulImpl(const void* run_options_ptr, T* out, T* lhs, T* rhs, int64 m,
-                int64 n, int64 k, int32 transpose_lhs, int32 transpose_rhs) {
+void MatMulDispatch(const void* run_options_ptr, T* out, T* lhs, T* rhs,
+                    int64 m, int64 n, int64 k, int32 transpose_lhs,
+                    int32 transpose_rhs) {
+  bool all_buffers_16b_aligned =
+      Is16BytesAligned(out) && Is16BytesAligned(lhs) && Is16BytesAligned(rhs);
+
+  if (!all_buffers_16b_aligned) {
+    MatMul<T, Eigen::Unaligned>(run_options_ptr, out, lhs, rhs, m, n, k,
+                                transpose_lhs, transpose_rhs);
+    return;
+  }
+
   if (m == 1 || n == 1) {
     // Despite being single threaded, this version of matrix * vector is faster.
     xla::EigenMatVec<T>(out, lhs, rhs, m, n, k, transpose_lhs, transpose_rhs);
   } else {
-    MatMul<T>(run_options_ptr, out, lhs, rhs, m, n, k, transpose_lhs,
-              transpose_rhs);
+    MatMul<T, Eigen::Aligned16>(run_options_ptr, out, lhs, rhs, m, n, k,
+                                transpose_lhs, transpose_rhs);
   }
 }
 
@@ -86,20 +100,20 @@ TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenMatMulF16(
     const void* run_options_ptr, Eigen::half* out, Eigen::half* lhs,
     Eigen::half* rhs, int64 m, int64 n, int64 k, int32 transpose_lhs,
     int32 transpose_rhs) {
-  MatMulImpl<Eigen::half>(run_options_ptr, out, lhs, rhs, m, n, k,
-                          transpose_lhs, transpose_rhs);
+  MatMulDispatch<Eigen::half>(run_options_ptr, out, lhs, rhs, m, n, k,
+                              transpose_lhs, transpose_rhs);
 }
 
 TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenMatMulF32(
     const void* run_options_ptr, float* out, float* lhs, float* rhs, int64 m,
     int64 n, int64 k, int32 transpose_lhs, int32 transpose_rhs) {
-  MatMulImpl<float>(run_options_ptr, out, lhs, rhs, m, n, k, transpose_lhs,
-                    transpose_rhs);
+  MatMulDispatch<float>(run_options_ptr, out, lhs, rhs, m, n, k, transpose_lhs,
+                        transpose_rhs);
 }
 
 TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenMatMulF64(
     const void* run_options_ptr, double* out, double* lhs, double* rhs, int64 m,
     int64 n, int64 k, int32 transpose_lhs, int32 transpose_rhs) {
-  MatMulImpl<double>(run_options_ptr, out, lhs, rhs, m, n, k, transpose_lhs,
-                     transpose_rhs);
+  MatMulDispatch<double>(run_options_ptr, out, lhs, rhs, m, n, k, transpose_lhs,
+                         transpose_rhs);
 }

tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.cc
@@ -25,7 +25,11 @@ using tensorflow::int64;
 
 namespace {
 
-template <typename T>
+bool Is16BytesAligned(void* ptr) {
+  return reinterpret_cast<uintptr_t>(ptr) % 16 == 0;
+}
+
+template <typename T, Eigen::AlignmentType Alignment>
 void MatMul(const void* run_options_ptr, T* out, T* lhs, T* rhs, int64 m,
             int64 n, int64 k, int32 transpose_lhs, int32 transpose_rhs) {
   int64 lhs_rows = m;
@@ -40,11 +44,11 @@ void MatMul(const void* run_options_ptr, T* out, T* lhs, T* rhs, int64 m,
     std::swap(rhs_rows, rhs_cols);
   }
 
-  const Eigen::TensorMap<Eigen::Tensor<const T, 2>, Eigen::Aligned> A(
-      lhs, lhs_rows, lhs_cols);
-  const Eigen::TensorMap<Eigen::Tensor<const T, 2>, Eigen::Aligned> B(
-      rhs, rhs_rows, rhs_cols);
-  Eigen::TensorMap<Eigen::Tensor<T, 2>, Eigen::Aligned> C(out, m, n);
+  const Eigen::TensorMap<Eigen::Tensor<const T, 2>, Alignment> A(lhs, lhs_rows,
+                                                                 lhs_cols);
+  const Eigen::TensorMap<Eigen::Tensor<const T, 2>, Alignment> B(rhs, rhs_rows,
+                                                                 rhs_cols);
+  Eigen::TensorMap<Eigen::Tensor<T, 2>, Alignment> C(out, m, n);
 
   typedef typename Eigen::Tensor<T, 2>::DimensionPair DimPair;
   int lhs_contract_dim = transpose_lhs ? 0 : 1;
@@ -59,14 +63,22 @@ void MatMul(const void* run_options_ptr, T* out, T* lhs, T* rhs, int64 m,
 }
 
 template <typename T>
-void SingleThreadedMatMul(const void* run_options_ptr, T* out, T* lhs, T* rhs,
-                          int64 m, int64 n, int64 k, int32 transpose_lhs,
-                          int32 transpose_rhs) {
+void SingleThreadedMatMulDispatch(const void* run_options_ptr, T* out, T* lhs,
+                                  T* rhs, int64 m, int64 n, int64 k,
+                                  int32 transpose_lhs, int32 transpose_rhs) {
+  bool all_buffers_16b_aligned =
+      Is16BytesAligned(out) && Is16BytesAligned(lhs) && Is16BytesAligned(rhs);
+
+  if (!all_buffers_16b_aligned) {
+    MatMul<T, Eigen::Unaligned>(run_options_ptr, out, lhs, rhs, m, n, k,
+                                transpose_lhs, transpose_rhs);
+  }
+
   if (m == 1 || n == 1) {
     xla::EigenMatVec<T>(out, lhs, rhs, m, n, k, transpose_lhs, transpose_rhs);
   } else {
-    MatMul<T>(run_options_ptr, out, lhs, rhs, m, n, k, transpose_lhs,
-              transpose_rhs);
+    MatMul<T, Eigen::Aligned16>(run_options_ptr, out, lhs, rhs, m, n, k,
+                                transpose_lhs, transpose_rhs);
   }
 }
 
@@ -77,8 +89,8 @@ __xla_cpu_runtime_EigenSingleThreadedMatMulF16(
     const void* run_options_ptr, Eigen::half* out, Eigen::half* lhs,
     Eigen::half* rhs, int64 m, int64 n, int64 k, int32 transpose_lhs,
     int32 transpose_rhs) {
-  SingleThreadedMatMul<Eigen::half>(run_options_ptr, out, lhs, rhs, m, n, k,
-                                    transpose_lhs, transpose_rhs);
+  SingleThreadedMatMulDispatch<Eigen::half>(run_options_ptr, out, lhs, rhs, m,
+                                            n, k, transpose_lhs, transpose_rhs);
 }
 
 TF_ATTRIBUTE_NO_SANITIZE_MEMORY void
@@ -87,8 +99,8 @@ __xla_cpu_runtime_EigenSingleThreadedMatMulF32(const void* run_options_ptr,
                                                float* rhs, int64 m, int64 n,
                                                int64 k, int32 transpose_lhs,
                                                int32 transpose_rhs) {
-  SingleThreadedMatMul<float>(run_options_ptr, out, lhs, rhs, m, n, k,
-                              transpose_lhs, transpose_rhs);
+  SingleThreadedMatMulDispatch<float>(run_options_ptr, out, lhs, rhs, m, n, k,
+                                      transpose_lhs, transpose_rhs);
 }
 
 TF_ATTRIBUTE_NO_SANITIZE_MEMORY void
@@ -97,6 +109,6 @@ __xla_cpu_runtime_EigenSingleThreadedMatMulF64(const void* run_options_ptr,
                                                double* rhs, int64 m, int64 n,
                                                int64 k, int32 transpose_lhs,
                                                int32 transpose_rhs) {
-  SingleThreadedMatMul<double>(run_options_ptr, out, lhs, rhs, m, n, k,
-                               transpose_lhs, transpose_rhs);
+  SingleThreadedMatMulDispatch<double>(run_options_ptr, out, lhs, rhs, m, n, k,
+                                       transpose_lhs, transpose_rhs);
 }