[XLA:CPU] Wire up C64/C128 matmul to Eigen
This is much faster than a naive loop. Also add some more testing now that we can support it in the evaluator. PiperOrigin-RevId: 313407740 Change-Id: I692de60af47e86a269ab4d121e97d2b472b7a8e3
This commit is contained in:
parent
68ededda03
commit
a5fef39a38
@ -67,6 +67,10 @@ extern const char* const kEigenMatMulF32SymbolName =
|
||||
"__xla_cpu_runtime_EigenMatMulF32";
|
||||
extern const char* const kEigenMatMulF64SymbolName =
|
||||
"__xla_cpu_runtime_EigenMatMulF64";
|
||||
extern const char* const kEigenMatMulC64SymbolName =
|
||||
"__xla_cpu_runtime_EigenMatMulC64";
|
||||
extern const char* const kEigenMatMulC128SymbolName =
|
||||
"__xla_cpu_runtime_EigenMatMulC128";
|
||||
extern const char* const kEigenMatMulS32SymbolName =
|
||||
"__xla_cpu_runtime_EigenMatMulS32";
|
||||
extern const char* const kMKLConvF32SymbolName = "__xla_cpu_runtime_MKLConvF32";
|
||||
@ -91,6 +95,10 @@ extern const char* const kEigenSingleThreadedMatMulF32SymbolName =
|
||||
"__xla_cpu_runtime_EigenSingleThreadedMatMulF32";
|
||||
extern const char* const kEigenSingleThreadedMatMulF64SymbolName =
|
||||
"__xla_cpu_runtime_EigenSingleThreadedMatMulF64";
|
||||
extern const char* const kEigenSingleThreadedMatMulC64SymbolName =
|
||||
"__xla_cpu_runtime_EigenSingleThreadedMatMulC64";
|
||||
extern const char* const kEigenSingleThreadedMatMulC128SymbolName =
|
||||
"__xla_cpu_runtime_EigenSingleThreadedMatMulC128";
|
||||
extern const char* const kEigenSingleThreadedMatMulS32SymbolName =
|
||||
"__xla_cpu_runtime_EigenSingleThreadedMatMulS32";
|
||||
extern const char* const kEigenSingleThreadedConvF16SymbolName =
|
||||
|
||||
@ -46,6 +46,8 @@ namespace runtime {
|
||||
extern const char* const kEigenMatMulF16SymbolName;
|
||||
extern const char* const kEigenMatMulF32SymbolName;
|
||||
extern const char* const kEigenMatMulF64SymbolName;
|
||||
extern const char* const kEigenMatMulC64SymbolName;
|
||||
extern const char* const kEigenMatMulC128SymbolName;
|
||||
extern const char* const kEigenMatMulS32SymbolName;
|
||||
extern const char* const kMKLConvF32SymbolName;
|
||||
extern const char* const kMKLMatMulF32SymbolName;
|
||||
@ -59,6 +61,8 @@ extern const char* const kEigenSingleThreadedFftSymbolName;
|
||||
extern const char* const kEigenSingleThreadedMatMulF16SymbolName;
|
||||
extern const char* const kEigenSingleThreadedMatMulF32SymbolName;
|
||||
extern const char* const kEigenSingleThreadedMatMulF64SymbolName;
|
||||
extern const char* const kEigenSingleThreadedMatMulC64SymbolName;
|
||||
extern const char* const kEigenSingleThreadedMatMulC128SymbolName;
|
||||
extern const char* const kEigenSingleThreadedMatMulS32SymbolName;
|
||||
extern const char* const kEigenSingleThreadedConvF16SymbolName;
|
||||
extern const char* const kEigenSingleThreadedConvF32SymbolName;
|
||||
|
||||
@ -657,6 +657,8 @@ Status DotOpEmitter::EmitCallToRuntime() {
|
||||
bool multi_threaded = ShouldUseMultiThreadedEigen(hlo_module_config_);
|
||||
bool use_mkl_dnn = hlo_module_config_.debug_options().xla_cpu_use_mkl_dnn();
|
||||
PrimitiveType type = target_array_.GetShape().element_type();
|
||||
llvm::Function* function = b_->GetInsertBlock()->getParent();
|
||||
llvm::Module* module = function->getParent();
|
||||
llvm::Type* float_type;
|
||||
const char* fn_name;
|
||||
switch (type) {
|
||||
@ -684,6 +686,18 @@ Status DotOpEmitter::EmitCallToRuntime() {
|
||||
: runtime::kEigenSingleThreadedMatMulF64SymbolName);
|
||||
float_type = b_->getDoubleTy();
|
||||
break;
|
||||
case C64:
|
||||
fn_name = multi_threaded
|
||||
? runtime::kEigenMatMulC64SymbolName
|
||||
: runtime::kEigenSingleThreadedMatMulC64SymbolName;
|
||||
float_type = llvm_ir::PrimitiveTypeToIrType(C64, module);
|
||||
break;
|
||||
case C128:
|
||||
fn_name = multi_threaded
|
||||
? runtime::kEigenMatMulC128SymbolName
|
||||
: runtime::kEigenSingleThreadedMatMulC128SymbolName;
|
||||
float_type = llvm_ir::PrimitiveTypeToIrType(C128, module);
|
||||
break;
|
||||
case S32:
|
||||
fn_name = multi_threaded
|
||||
? runtime::kEigenMatMulS32SymbolName
|
||||
@ -705,9 +719,6 @@ Status DotOpEmitter::EmitCallToRuntime() {
|
||||
int64_type, int64_type, int64_type, int32_type, int32_type},
|
||||
/*isVarArg=*/false);
|
||||
|
||||
llvm::Function* function = b_->GetInsertBlock()->getParent();
|
||||
llvm::Module* module = function->getParent();
|
||||
|
||||
llvm::FunctionCallee matmul_func =
|
||||
module->getOrInsertFunction(fn_name, matmul_type);
|
||||
if (auto* fn = llvm::dyn_cast<llvm::Function>(matmul_func.getCallee())) {
|
||||
@ -853,9 +864,11 @@ bool AreGemmShapes(const Shape& lhs_shape, const Shape& rhs_shape,
|
||||
<< output_shape.DebugString();
|
||||
|
||||
switch (output_shape.element_type()) {
|
||||
case F64:
|
||||
case F32:
|
||||
case F16:
|
||||
case F32:
|
||||
case F64:
|
||||
case C64:
|
||||
case C128:
|
||||
case S32:
|
||||
return IsRank2(lhs_shape) && IsRank2(rhs_shape) && IsRank2(output_shape);
|
||||
default:
|
||||
@ -904,7 +917,9 @@ bool CanEmitTiledLlvmIrGemm(
|
||||
return false;
|
||||
}
|
||||
|
||||
if (dot_info.result_shape.element_type() == F16) {
|
||||
if (dot_info.result_shape.element_type() == F16 ||
|
||||
dot_info.result_shape.element_type() == C64 ||
|
||||
dot_info.result_shape.element_type() == C128) {
|
||||
// TODO(sanjoy): This is probably easy to fix, but I want to keep the CL
|
||||
// adding this comment NFC.
|
||||
return false;
|
||||
|
||||
@ -114,6 +114,22 @@ TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenMatMulF64(
|
||||
transpose_rhs);
|
||||
}
|
||||
|
||||
TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenMatMulC64(
|
||||
const void* run_options_ptr, std::complex<float>* out,
|
||||
std::complex<float>* lhs, std::complex<float>* rhs, int64 m, int64 n,
|
||||
int64 k, int32 transpose_lhs, int32 transpose_rhs) {
|
||||
MatMulDispatch<std::complex<float>>(run_options_ptr, out, lhs, rhs, m, n, k,
|
||||
transpose_lhs, transpose_rhs);
|
||||
}
|
||||
|
||||
TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenMatMulC128(
|
||||
const void* run_options_ptr, std::complex<double>* out,
|
||||
std::complex<double>* lhs, std::complex<double>* rhs, int64 m, int64 n,
|
||||
int64 k, int32 transpose_lhs, int32 transpose_rhs) {
|
||||
MatMulDispatch<std::complex<double>>(run_options_ptr, out, lhs, rhs, m, n, k,
|
||||
transpose_lhs, transpose_rhs);
|
||||
}
|
||||
|
||||
TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenMatMulS32(
|
||||
const void* run_options_ptr, int32* out, int32* lhs, int32* rhs, int64 m,
|
||||
int64 n, int64 k, int32 transpose_lhs, int32 transpose_rhs) {
|
||||
|
||||
@ -16,6 +16,8 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_MATMUL_H_
|
||||
#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_MATMUL_H_
|
||||
|
||||
#include <complex>
|
||||
|
||||
#include "third_party/eigen3/Eigen/Core"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
@ -44,6 +46,18 @@ extern void __xla_cpu_runtime_EigenMatMulF64(
|
||||
tensorflow::int64 k, tensorflow::int32 transpose_lhs,
|
||||
tensorflow::int32 transpose_rhs);
|
||||
|
||||
extern void __xla_cpu_runtime_EigenMatMulC64(
|
||||
const void* run_options_ptr, std::complex<float>* out,
|
||||
std::complex<float>* lhs, std::complex<float>* rhs, tensorflow::int64 m,
|
||||
tensorflow::int64 n, tensorflow::int64 k, tensorflow::int32 transpose_lhs,
|
||||
tensorflow::int32 transpose_rhs);
|
||||
|
||||
extern void __xla_cpu_runtime_EigenMatMulC128(
|
||||
const void* run_options_ptr, std::complex<double>* out,
|
||||
std::complex<double>* lhs, std::complex<double>* rhs, tensorflow::int64 m,
|
||||
tensorflow::int64 n, tensorflow::int64 k, tensorflow::int32 transpose_lhs,
|
||||
tensorflow::int32 transpose_rhs);
|
||||
|
||||
extern void __xla_cpu_runtime_EigenMatMulS32(
|
||||
const void* /* xla::ExecutableRunOptions* */ run_options_ptr,
|
||||
tensorflow::int32* out, tensorflow::int32* lhs, tensorflow::int32* rhs,
|
||||
|
||||
@ -112,6 +112,24 @@ __xla_cpu_runtime_EigenSingleThreadedMatMulF64(const void* run_options_ptr,
|
||||
transpose_lhs, transpose_rhs);
|
||||
}
|
||||
|
||||
TF_ATTRIBUTE_NO_SANITIZE_MEMORY void
|
||||
__xla_cpu_runtime_EigenSingleThreadedMatMulC64(
|
||||
const void* run_options_ptr, std::complex<float>* out,
|
||||
std::complex<float>* lhs, std::complex<float>* rhs, int64 m, int64 n,
|
||||
int64 k, int32 transpose_lhs, int32 transpose_rhs) {
|
||||
SingleThreadedMatMulDispatch<std::complex<float>>(
|
||||
run_options_ptr, out, lhs, rhs, m, n, k, transpose_lhs, transpose_rhs);
|
||||
}
|
||||
|
||||
TF_ATTRIBUTE_NO_SANITIZE_MEMORY void
|
||||
__xla_cpu_runtime_EigenSingleThreadedMatMulC128(
|
||||
const void* run_options_ptr, std::complex<double>* out,
|
||||
std::complex<double>* lhs, std::complex<double>* rhs, int64 m, int64 n,
|
||||
int64 k, int32 transpose_lhs, int32 transpose_rhs) {
|
||||
SingleThreadedMatMulDispatch<std::complex<double>>(
|
||||
run_options_ptr, out, lhs, rhs, m, n, k, transpose_lhs, transpose_rhs);
|
||||
}
|
||||
|
||||
TF_ATTRIBUTE_NO_SANITIZE_MEMORY void
|
||||
__xla_cpu_runtime_EigenSingleThreadedMatMulS32(const void* run_options_ptr,
|
||||
int32* out, int32* lhs,
|
||||
|
||||
@ -16,6 +16,8 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_SINGLE_THREADED_MATMUL_H_
|
||||
#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_SINGLE_THREADED_MATMUL_H_
|
||||
|
||||
#include <complex>
|
||||
|
||||
#include "third_party/eigen3/Eigen/Core"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
@ -44,6 +46,20 @@ extern void __xla_cpu_runtime_EigenSingleThreadedMatMulF64(
|
||||
tensorflow::int64 k, tensorflow::int32 transpose_lhs,
|
||||
tensorflow::int32 transpose_rhs);
|
||||
|
||||
extern void __xla_cpu_runtime_EigenSingleThreadedMatMulC64(
|
||||
const void* /* xla::ExecutableRunOptions* */ run_options_ptr,
|
||||
std::complex<float>* out, std::complex<float>* lhs,
|
||||
std::complex<float>* rhs, tensorflow::int64 m, tensorflow::int64 n,
|
||||
tensorflow::int64 k, tensorflow::int32 transpose_lhs,
|
||||
tensorflow::int32 transpose_rhs);
|
||||
|
||||
extern void __xla_cpu_runtime_EigenSingleThreadedMatMulC128(
|
||||
const void* /* xla::ExecutableRunOptions* */ run_options_ptr,
|
||||
std::complex<double>* out, std::complex<double>* lhs,
|
||||
std::complex<double>* rhs, tensorflow::int64 m, tensorflow::int64 n,
|
||||
tensorflow::int64 k, tensorflow::int32 transpose_lhs,
|
||||
tensorflow::int32 transpose_rhs);
|
||||
|
||||
extern void __xla_cpu_runtime_EigenSingleThreadedMatMulS32(
|
||||
const void* /* xla::ExecutableRunOptions* */ run_options_ptr,
|
||||
tensorflow::int32* out, tensorflow::int32* lhs, tensorflow::int32* rhs,
|
||||
|
||||
@ -246,6 +246,8 @@ bool RegisterKnownJITSymbols() {
|
||||
REGISTER_CPU_RUNTIME_SYMBOL(EigenMatMulF16);
|
||||
REGISTER_CPU_RUNTIME_SYMBOL(EigenMatMulF32);
|
||||
REGISTER_CPU_RUNTIME_SYMBOL(EigenMatMulF64);
|
||||
REGISTER_CPU_RUNTIME_SYMBOL(EigenMatMulC64);
|
||||
REGISTER_CPU_RUNTIME_SYMBOL(EigenMatMulC128);
|
||||
REGISTER_CPU_RUNTIME_SYMBOL(EigenMatMulS32);
|
||||
REGISTER_CPU_RUNTIME_SYMBOL(MKLMatMulF32);
|
||||
REGISTER_CPU_RUNTIME_SYMBOL(MKLMatMulF64);
|
||||
@ -257,6 +259,8 @@ bool RegisterKnownJITSymbols() {
|
||||
REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedMatMulF16);
|
||||
REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedMatMulF32);
|
||||
REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedMatMulF64);
|
||||
REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedMatMulC64);
|
||||
REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedMatMulC128);
|
||||
REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedMatMulS32);
|
||||
REGISTER_CPU_RUNTIME_SYMBOL(ParallelForkJoin);
|
||||
REGISTER_CPU_RUNTIME_SYMBOL(ReleaseInfeedBufferAfterDequeue);
|
||||
|
||||
@ -2556,6 +2556,20 @@ std::unique_ptr<Array2D<double>> HloEvaluator::MatmulArray2D(
|
||||
lhs, rhs, __xla_cpu_runtime_EigenSingleThreadedMatMulF64);
|
||||
}
|
||||
|
||||
std::unique_ptr<Array2D<std::complex<float>>> HloEvaluator::MatmulArray2D(
|
||||
const Array2D<std::complex<float>>& lhs,
|
||||
const Array2D<std::complex<float>>& rhs) {
|
||||
return MatmulArray2DImpl<std::complex<float>>(
|
||||
lhs, rhs, __xla_cpu_runtime_EigenSingleThreadedMatMulC64);
|
||||
}
|
||||
|
||||
std::unique_ptr<Array2D<std::complex<double>>> HloEvaluator::MatmulArray2D(
|
||||
const Array2D<std::complex<double>>& lhs,
|
||||
const Array2D<std::complex<double>>& rhs) {
|
||||
return MatmulArray2DImpl<std::complex<double>>(
|
||||
lhs, rhs, __xla_cpu_runtime_EigenSingleThreadedMatMulC128);
|
||||
}
|
||||
|
||||
std::unique_ptr<Array2D<int32>> HloEvaluator::MatmulArray2D(
|
||||
const Array2D<int32>& lhs, const Array2D<int32>& rhs) {
|
||||
return MatmulArray2DImpl<int32>(
|
||||
|
||||
@ -164,6 +164,12 @@ class HloEvaluator : public DfsHloVisitorWithDefault {
|
||||
const Array2D<float>& lhs, const Array2D<float>& rhs);
|
||||
static std::unique_ptr<Array2D<double>> MatmulArray2D(
|
||||
const Array2D<double>& lhs, const Array2D<double>& rhs);
|
||||
static std::unique_ptr<Array2D<std::complex<float>>> MatmulArray2D(
|
||||
const Array2D<std::complex<float>>& lhs,
|
||||
const Array2D<std::complex<float>>& rhs);
|
||||
static std::unique_ptr<Array2D<std::complex<double>>> MatmulArray2D(
|
||||
const Array2D<std::complex<double>>& lhs,
|
||||
const Array2D<std::complex<double>>& rhs);
|
||||
static std::unique_ptr<Array2D<int32>> MatmulArray2D(
|
||||
const Array2D<int32>& lhs, const Array2D<int32>& rhs);
|
||||
|
||||
|
||||
@ -416,6 +416,10 @@ XLA_TEST_P(ParametricDotTest, TestF16) { TestImpl<Eigen::half>(); }
|
||||
#endif
|
||||
XLA_TEST_P(ParametricDotTest, TestF32) { TestImpl<float>(); }
|
||||
XLA_TEST_P(ParametricDotTest, TestF64) { TestImpl<double>(); }
|
||||
XLA_TEST_P(ParametricDotTest, TestC64) { TestImpl<std::complex<float>>(); }
|
||||
#ifndef XLA_BACKEND_DOES_NOT_SUPPORT_COMPLEX128
|
||||
XLA_TEST_P(ParametricDotTest, TestC128) { TestImpl<std::complex<double>>(); }
|
||||
#endif
|
||||
XLA_TEST_P(ParametricDotTest, TestS32) { TestImpl<int32>(); }
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(DotTests, ParametricDotTest,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user