[XLA:CPU] Wire up C64/C128 matmul to Eigen

This is much faster than a naive loop. Also add some more testing now that we
can support it in the evaluator.

PiperOrigin-RevId: 313407740
Change-Id: I692de60af47e86a269ab4d121e97d2b472b7a8e3
Author: Benjamin Kramer, 2020-05-27 09:50:04 -07:00
Committed by: TensorFlower Gardener
Parent: 68ededda03
Commit: a5fef39a38
11 changed files with 125 additions and 6 deletions
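
For reference, the "naive loop" the commit message compares against looks roughly like the sketch below. This is an illustration only, not code from this change; the helper name and the column-major layout assumption are hypothetical.

#include <complex>
#include <cstdint>

// Hypothetical naive kernel: multiply an m x k by a k x n matrix, both
// column-major, in O(m * n * k) scalar complex operations, with none of the
// cache blocking or vectorization that Eigen's GEMM supplies.
void NaiveMatMulC64(const std::complex<float>* lhs,
                    const std::complex<float>* rhs, std::complex<float>* out,
                    int64_t m, int64_t n, int64_t k) {
  for (int64_t j = 0; j < n; ++j) {
    for (int64_t i = 0; i < m; ++i) {
      std::complex<float> acc(0, 0);
      for (int64_t l = 0; l < k; ++l) {
        acc += lhs[i + l * m] * rhs[l + j * k];  // lhs(i, l) * rhs(l, j)
      }
      out[i + j * m] = acc;  // out(i, j)
    }
  }
}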

--- a/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc

@@ -67,6 +67,10 @@ extern const char* const kEigenMatMulF32SymbolName =
     "__xla_cpu_runtime_EigenMatMulF32";
 extern const char* const kEigenMatMulF64SymbolName =
     "__xla_cpu_runtime_EigenMatMulF64";
+extern const char* const kEigenMatMulC64SymbolName =
+    "__xla_cpu_runtime_EigenMatMulC64";
+extern const char* const kEigenMatMulC128SymbolName =
+    "__xla_cpu_runtime_EigenMatMulC128";
 extern const char* const kEigenMatMulS32SymbolName =
     "__xla_cpu_runtime_EigenMatMulS32";
 extern const char* const kMKLConvF32SymbolName = "__xla_cpu_runtime_MKLConvF32";
@@ -91,6 +95,10 @@ extern const char* const kEigenSingleThreadedMatMulF32SymbolName =
     "__xla_cpu_runtime_EigenSingleThreadedMatMulF32";
 extern const char* const kEigenSingleThreadedMatMulF64SymbolName =
     "__xla_cpu_runtime_EigenSingleThreadedMatMulF64";
+extern const char* const kEigenSingleThreadedMatMulC64SymbolName =
+    "__xla_cpu_runtime_EigenSingleThreadedMatMulC64";
+extern const char* const kEigenSingleThreadedMatMulC128SymbolName =
+    "__xla_cpu_runtime_EigenSingleThreadedMatMulC128";
 extern const char* const kEigenSingleThreadedMatMulS32SymbolName =
     "__xla_cpu_runtime_EigenSingleThreadedMatMulS32";
 extern const char* const kEigenSingleThreadedConvF16SymbolName =

--- a/tensorflow/compiler/xla/service/cpu/cpu_runtime.h
+++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime.h

@@ -46,6 +46,8 @@ namespace runtime {
 extern const char* const kEigenMatMulF16SymbolName;
 extern const char* const kEigenMatMulF32SymbolName;
 extern const char* const kEigenMatMulF64SymbolName;
+extern const char* const kEigenMatMulC64SymbolName;
+extern const char* const kEigenMatMulC128SymbolName;
 extern const char* const kEigenMatMulS32SymbolName;
 extern const char* const kMKLConvF32SymbolName;
 extern const char* const kMKLMatMulF32SymbolName;
@@ -59,6 +61,8 @@ extern const char* const kEigenSingleThreadedFftSymbolName;
 extern const char* const kEigenSingleThreadedMatMulF16SymbolName;
 extern const char* const kEigenSingleThreadedMatMulF32SymbolName;
 extern const char* const kEigenSingleThreadedMatMulF64SymbolName;
+extern const char* const kEigenSingleThreadedMatMulC64SymbolName;
+extern const char* const kEigenSingleThreadedMatMulC128SymbolName;
 extern const char* const kEigenSingleThreadedMatMulS32SymbolName;
 extern const char* const kEigenSingleThreadedConvF16SymbolName;
 extern const char* const kEigenSingleThreadedConvF32SymbolName;

--- a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc
+++ b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc

@@ -657,6 +657,8 @@ Status DotOpEmitter::EmitCallToRuntime() {
   bool multi_threaded = ShouldUseMultiThreadedEigen(hlo_module_config_);
   bool use_mkl_dnn = hlo_module_config_.debug_options().xla_cpu_use_mkl_dnn();
   PrimitiveType type = target_array_.GetShape().element_type();
+  llvm::Function* function = b_->GetInsertBlock()->getParent();
+  llvm::Module* module = function->getParent();
   llvm::Type* float_type;
   const char* fn_name;
   switch (type) {
@@ -684,6 +686,18 @@ Status DotOpEmitter::EmitCallToRuntime() {
                     : runtime::kEigenSingleThreadedMatMulF64SymbolName);
       float_type = b_->getDoubleTy();
       break;
+    case C64:
+      fn_name = multi_threaded
+                    ? runtime::kEigenMatMulC64SymbolName
+                    : runtime::kEigenSingleThreadedMatMulC64SymbolName;
+      float_type = llvm_ir::PrimitiveTypeToIrType(C64, module);
+      break;
+    case C128:
+      fn_name = multi_threaded
+                    ? runtime::kEigenMatMulC128SymbolName
+                    : runtime::kEigenSingleThreadedMatMulC128SymbolName;
+      float_type = llvm_ir::PrimitiveTypeToIrType(C128, module);
+      break;
     case S32:
       fn_name = multi_threaded
                     ? runtime::kEigenMatMulS32SymbolName
@@ -705,9 +719,6 @@ Status DotOpEmitter::EmitCallToRuntime() {
        int64_type, int64_type, int64_type, int32_type, int32_type},
       /*isVarArg=*/false);
-  llvm::Function* function = b_->GetInsertBlock()->getParent();
-  llvm::Module* module = function->getParent();
   llvm::FunctionCallee matmul_func =
       module->getOrInsertFunction(fn_name, matmul_type);
   if (auto* fn = llvm::dyn_cast<llvm::Function>(matmul_func.getCallee())) {
@@ -853,9 +864,11 @@ bool AreGemmShapes(const Shape& lhs_shape, const Shape& rhs_shape,
          << output_shape.DebugString();
   switch (output_shape.element_type()) {
-    case F64:
-    case F32:
     case F16:
+    case F32:
+    case F64:
+    case C64:
+    case C128:
     case S32:
       return IsRank2(lhs_shape) && IsRank2(rhs_shape) && IsRank2(output_shape);
     default:
@@ -904,7 +917,9 @@ bool CanEmitTiledLlvmIrGemm(
     return false;
   }
-  if (dot_info.result_shape.element_type() == F16) {
+  if (dot_info.result_shape.element_type() == F16 ||
+      dot_info.result_shape.element_type() == C64 ||
+      dot_info.result_shape.element_type() == C128) {
     // TODO(sanjoy): This is probably easy to fix, but I want to keep the CL
     // adding this comment NFC.
     return false;
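
The hunks above show only fragments of the call-site plumbing in EmitCallToRuntime(). A simplified sketch of how it presumably fits together, reconstructed around those fragments (not an exact copy of the file): one function type is shared by every runtime entry point, with float_type standing in for the element type, which is why `module` now has to be computed before the switch so that llvm_ir::PrimitiveTypeToIrType can use it for C64/C128.

// Simplified sketch, not an exact copy of dot_op_emitter.cc.
llvm::Type* ptr_type = float_type->getPointerTo();
llvm::Type* int64_type = b_->getInt64Ty();
llvm::Type* int32_type = b_->getInt32Ty();
llvm::FunctionType* matmul_type = llvm::FunctionType::get(
    b_->getVoidTy(),
    {/*run_options_ptr=*/b_->getInt8PtrTy(),
     /*out, lhs, rhs=*/ptr_type, ptr_type, ptr_type,
     /*m, n, k=*/int64_type, int64_type, int64_type,
     /*transpose_lhs, transpose_rhs=*/int32_type, int32_type},
    /*isVarArg=*/false);
llvm::FunctionCallee matmul_func =
    module->getOrInsertFunction(fn_name, matmul_type);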

--- a/tensorflow/compiler/xla/service/cpu/runtime_matmul.cc
+++ b/tensorflow/compiler/xla/service/cpu/runtime_matmul.cc

@@ -114,6 +114,22 @@ TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenMatMulF64(
                          transpose_rhs);
 }
 
+TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenMatMulC64(
+    const void* run_options_ptr, std::complex<float>* out,
+    std::complex<float>* lhs, std::complex<float>* rhs, int64 m, int64 n,
+    int64 k, int32 transpose_lhs, int32 transpose_rhs) {
+  MatMulDispatch<std::complex<float>>(run_options_ptr, out, lhs, rhs, m, n, k,
+                                      transpose_lhs, transpose_rhs);
+}
+
+TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenMatMulC128(
+    const void* run_options_ptr, std::complex<double>* out,
+    std::complex<double>* lhs, std::complex<double>* rhs, int64 m, int64 n,
+    int64 k, int32 transpose_lhs, int32 transpose_rhs) {
+  MatMulDispatch<std::complex<double>>(run_options_ptr, out, lhs, rhs, m, n, k,
+                                       transpose_lhs, transpose_rhs);
+}
+
 TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenMatMulS32(
     const void* run_options_ptr, int32* out, int32* lhs, int32* rhs, int64 m,
     int64 n, int64 k, int32 transpose_lhs, int32 transpose_rhs) {
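
MatMulDispatch<T> is defined earlier in this file and is not shown in the hunk. A minimal sketch of the pattern it implements, assuming column-major buffers; the name MatMulSketch and the use of Eigen's dense product are illustrative, since the file itself works in terms of Eigen contractions, with a thread-pool device for the multi-threaded symbols.

#include <complex>
#include <cstdint>

#include "third_party/eigen3/Eigen/Core"

// Illustrative dispatch: map the raw buffers as column-major matrices and
// let Eigen's GEMM handle blocking and vectorization. The transpose flags
// say which operands are stored transposed.
template <typename T>
void MatMulSketch(T* out, const T* lhs, const T* rhs, int64_t m, int64_t n,
                  int64_t k, bool transpose_lhs, bool transpose_rhs) {
  using Matrix =
      Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::ColMajor>;
  Eigen::Map<const Matrix> a(lhs, transpose_lhs ? k : m, transpose_lhs ? m : k);
  Eigen::Map<const Matrix> b(rhs, transpose_rhs ? n : k, transpose_rhs ? k : n);
  Eigen::Map<Matrix> c(out, m, n);
  if (transpose_lhs && transpose_rhs) {
    c.noalias() = a.transpose() * b.transpose();
  } else if (transpose_lhs) {
    c.noalias() = a.transpose() * b;
  } else if (transpose_rhs) {
    c.noalias() = a * b.transpose();
  } else {
    c.noalias() = a * b;
  }
}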

--- a/tensorflow/compiler/xla/service/cpu/runtime_matmul.h
+++ b/tensorflow/compiler/xla/service/cpu/runtime_matmul.h

@@ -16,6 +16,8 @@ limitations under the License.
 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_MATMUL_H_
 #define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_MATMUL_H_
 
+#include <complex>
+
 #include "third_party/eigen3/Eigen/Core"
 #include "tensorflow/core/platform/types.h"
@@ -44,6 +46,18 @@ extern void __xla_cpu_runtime_EigenMatMulF64(
     tensorflow::int64 k, tensorflow::int32 transpose_lhs,
     tensorflow::int32 transpose_rhs);
 
+extern void __xla_cpu_runtime_EigenMatMulC64(
+    const void* run_options_ptr, std::complex<float>* out,
+    std::complex<float>* lhs, std::complex<float>* rhs, tensorflow::int64 m,
+    tensorflow::int64 n, tensorflow::int64 k, tensorflow::int32 transpose_lhs,
+    tensorflow::int32 transpose_rhs);
+
+extern void __xla_cpu_runtime_EigenMatMulC128(
+    const void* run_options_ptr, std::complex<double>* out,
+    std::complex<double>* lhs, std::complex<double>* rhs, tensorflow::int64 m,
+    tensorflow::int64 n, tensorflow::int64 k, tensorflow::int32 transpose_lhs,
+    tensorflow::int32 transpose_rhs);
+
 extern void __xla_cpu_runtime_EigenMatMulS32(
     const void* /* xla::ExecutableRunOptions* */ run_options_ptr,
     tensorflow::int32* out, tensorflow::int32* lhs, tensorflow::int32* rhs,

--- a/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.cc
+++ b/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.cc

@@ -112,6 +112,24 @@ __xla_cpu_runtime_EigenSingleThreadedMatMulF64(const void* run_options_ptr,
                                        transpose_lhs, transpose_rhs);
 }
 
+TF_ATTRIBUTE_NO_SANITIZE_MEMORY void
+__xla_cpu_runtime_EigenSingleThreadedMatMulC64(
+    const void* run_options_ptr, std::complex<float>* out,
+    std::complex<float>* lhs, std::complex<float>* rhs, int64 m, int64 n,
+    int64 k, int32 transpose_lhs, int32 transpose_rhs) {
+  SingleThreadedMatMulDispatch<std::complex<float>>(
+      run_options_ptr, out, lhs, rhs, m, n, k, transpose_lhs, transpose_rhs);
+}
+
+TF_ATTRIBUTE_NO_SANITIZE_MEMORY void
+__xla_cpu_runtime_EigenSingleThreadedMatMulC128(
+    const void* run_options_ptr, std::complex<double>* out,
+    std::complex<double>* lhs, std::complex<double>* rhs, int64 m, int64 n,
+    int64 k, int32 transpose_lhs, int32 transpose_rhs) {
+  SingleThreadedMatMulDispatch<std::complex<double>>(
+      run_options_ptr, out, lhs, rhs, m, n, k, transpose_lhs, transpose_rhs);
+}
+
 TF_ATTRIBUTE_NO_SANITIZE_MEMORY void
 __xla_cpu_runtime_EigenSingleThreadedMatMulS32(const void* run_options_ptr,
                                                int32* out, int32* lhs,

--- a/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h
+++ b/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h

@@ -16,6 +16,8 @@ limitations under the License.
 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_SINGLE_THREADED_MATMUL_H_
 #define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_SINGLE_THREADED_MATMUL_H_
 
+#include <complex>
+
 #include "third_party/eigen3/Eigen/Core"
 #include "tensorflow/core/platform/types.h"
@@ -44,6 +46,20 @@ extern void __xla_cpu_runtime_EigenSingleThreadedMatMulF64(
     tensorflow::int64 k, tensorflow::int32 transpose_lhs,
     tensorflow::int32 transpose_rhs);
 
+extern void __xla_cpu_runtime_EigenSingleThreadedMatMulC64(
+    const void* /* xla::ExecutableRunOptions* */ run_options_ptr,
+    std::complex<float>* out, std::complex<float>* lhs,
+    std::complex<float>* rhs, tensorflow::int64 m, tensorflow::int64 n,
+    tensorflow::int64 k, tensorflow::int32 transpose_lhs,
+    tensorflow::int32 transpose_rhs);
+
+extern void __xla_cpu_runtime_EigenSingleThreadedMatMulC128(
+    const void* /* xla::ExecutableRunOptions* */ run_options_ptr,
+    std::complex<double>* out, std::complex<double>* lhs,
+    std::complex<double>* rhs, tensorflow::int64 m, tensorflow::int64 n,
+    tensorflow::int64 k, tensorflow::int32 transpose_lhs,
+    tensorflow::int32 transpose_rhs);
+
 extern void __xla_cpu_runtime_EigenSingleThreadedMatMulS32(
     const void* /* xla::ExecutableRunOptions* */ run_options_ptr,
     tensorflow::int32* out, tensorflow::int32* lhs, tensorflow::int32* rhs,

--- a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc
+++ b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc

@@ -246,6 +246,8 @@ bool RegisterKnownJITSymbols() {
   REGISTER_CPU_RUNTIME_SYMBOL(EigenMatMulF16);
   REGISTER_CPU_RUNTIME_SYMBOL(EigenMatMulF32);
   REGISTER_CPU_RUNTIME_SYMBOL(EigenMatMulF64);
+  REGISTER_CPU_RUNTIME_SYMBOL(EigenMatMulC64);
+  REGISTER_CPU_RUNTIME_SYMBOL(EigenMatMulC128);
   REGISTER_CPU_RUNTIME_SYMBOL(EigenMatMulS32);
   REGISTER_CPU_RUNTIME_SYMBOL(MKLMatMulF32);
   REGISTER_CPU_RUNTIME_SYMBOL(MKLMatMulF64);
@@ -257,6 +259,8 @@ bool RegisterKnownJITSymbols() {
   REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedMatMulF16);
   REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedMatMulF32);
   REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedMatMulF64);
+  REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedMatMulC64);
+  REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedMatMulC128);
   REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedMatMulS32);
   REGISTER_CPU_RUNTIME_SYMBOL(ParallelForkJoin);
   REGISTER_CPU_RUNTIME_SYMBOL(ReleaseInfeedBufferAfterDequeue);

--- a/tensorflow/compiler/xla/service/hlo_evaluator.cc
+++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc

@@ -2556,6 +2556,20 @@ std::unique_ptr<Array2D<double>> HloEvaluator::MatmulArray2D(
       lhs, rhs, __xla_cpu_runtime_EigenSingleThreadedMatMulF64);
 }
 
+std::unique_ptr<Array2D<std::complex<float>>> HloEvaluator::MatmulArray2D(
+    const Array2D<std::complex<float>>& lhs,
+    const Array2D<std::complex<float>>& rhs) {
+  return MatmulArray2DImpl<std::complex<float>>(
+      lhs, rhs, __xla_cpu_runtime_EigenSingleThreadedMatMulC64);
+}
+
+std::unique_ptr<Array2D<std::complex<double>>> HloEvaluator::MatmulArray2D(
+    const Array2D<std::complex<double>>& lhs,
+    const Array2D<std::complex<double>>& rhs) {
+  return MatmulArray2DImpl<std::complex<double>>(
+      lhs, rhs, __xla_cpu_runtime_EigenSingleThreadedMatMulC128);
+}
+
 std::unique_ptr<Array2D<int32>> HloEvaluator::MatmulArray2D(
     const Array2D<int32>& lhs, const Array2D<int32>& rhs) {
   return MatmulArray2DImpl<int32>(

--- a/tensorflow/compiler/xla/service/hlo_evaluator.h
+++ b/tensorflow/compiler/xla/service/hlo_evaluator.h

@@ -164,6 +164,12 @@ class HloEvaluator : public DfsHloVisitorWithDefault {
       const Array2D<float>& lhs, const Array2D<float>& rhs);
   static std::unique_ptr<Array2D<double>> MatmulArray2D(
       const Array2D<double>& lhs, const Array2D<double>& rhs);
+  static std::unique_ptr<Array2D<std::complex<float>>> MatmulArray2D(
+      const Array2D<std::complex<float>>& lhs,
+      const Array2D<std::complex<float>>& rhs);
+  static std::unique_ptr<Array2D<std::complex<double>>> MatmulArray2D(
+      const Array2D<std::complex<double>>& lhs,
+      const Array2D<std::complex<double>>& rhs);
   static std::unique_ptr<Array2D<int32>> MatmulArray2D(
       const Array2D<int32>& lhs, const Array2D<int32>& rhs);
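
With these overloads declared, evaluator clients can multiply complex matrices directly. A hypothetical caller (the Example function and its values are illustrative, not from the commit):

#include <complex>
#include <memory>

#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/service/hlo_evaluator.h"

// Multiply two 2x2 complex<float> matrices through the new overload, which
// runs the single-threaded Eigen kernel underneath.
void Example() {
  using C64 = std::complex<float>;
  xla::Array2D<C64> lhs({{C64(1, 2), C64(3, 4)}, {C64(5, 6), C64(7, 8)}});
  xla::Array2D<C64> rhs({{C64(1, 0), C64(0, 1)}, {C64(0, -1), C64(1, 0)}});
  std::unique_ptr<xla::Array2D<C64>> out =
      xla::HloEvaluator::MatmulArray2D(lhs, rhs);
}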

--- a/tensorflow/compiler/xla/tests/dot_operation_test.cc
+++ b/tensorflow/compiler/xla/tests/dot_operation_test.cc

@@ -416,6 +416,10 @@ XLA_TEST_P(ParametricDotTest, TestF16) { TestImpl<Eigen::half>(); }
 #endif
 XLA_TEST_P(ParametricDotTest, TestF32) { TestImpl<float>(); }
 XLA_TEST_P(ParametricDotTest, TestF64) { TestImpl<double>(); }
+XLA_TEST_P(ParametricDotTest, TestC64) { TestImpl<std::complex<float>>(); }
+#ifndef XLA_BACKEND_DOES_NOT_SUPPORT_COMPLEX128
+XLA_TEST_P(ParametricDotTest, TestC128) { TestImpl<std::complex<double>>(); }
+#endif
 XLA_TEST_P(ParametricDotTest, TestS32) { TestImpl<int32>(); }
 
 INSTANTIATE_TEST_CASE_P(DotTests, ParametricDotTest,