Merge pull request #39325 from Intel-tensorflow:sriniva2/dnnl_sgemm
PiperOrigin-RevId: 312846725 Change-Id: I8204aea5531380ab1f916c4053f4049d01a788d5
This commit is contained in:
commit
6333fef206
|
@ -506,7 +506,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
|
||||||
CopyAttrsAll, LrnGradRewrite, kRewriteForLayoutPropagation});
|
CopyAttrsAll, LrnGradRewrite, kRewriteForLayoutPropagation});
|
||||||
rinfo_.push_back({csinfo_.matmul,
|
rinfo_.push_back({csinfo_.matmul,
|
||||||
mkl_op_registry::GetMklOpName(csinfo_.matmul),
|
mkl_op_registry::GetMklOpName(csinfo_.matmul),
|
||||||
CopyAttrsAll, AlwaysRewrite, kRewriteForOpNameChange});
|
CopyAttrsAll, MatMulRewrite, kRewriteForOpNameChange});
|
||||||
rinfo_.push_back(
|
rinfo_.push_back(
|
||||||
{csinfo_.leakyrelu, mkl_op_registry::GetMklOpName(csinfo_.leakyrelu),
|
{csinfo_.leakyrelu, mkl_op_registry::GetMklOpName(csinfo_.leakyrelu),
|
||||||
CopyAttrsAll, LeakyReluRewrite, kRewriteForLayoutPropagation});
|
CopyAttrsAll, LeakyReluRewrite, kRewriteForLayoutPropagation});
|
||||||
|
@ -1482,6 +1482,16 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static bool MatMulRewrite(const Node* n) {
|
||||||
|
DataType T;
|
||||||
|
GetNodeAttr(n->def(), "T", &T);
|
||||||
|
if ((T == DT_FLOAT) || (T == DT_BFLOAT16)) {
|
||||||
|
VLOG(2) << "Rewriting MatMul to _MklMatMul";
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
static bool DequantizeRewrite(const Node* n) {
|
static bool DequantizeRewrite(const Node* n) {
|
||||||
DCHECK(n);
|
DCHECK(n);
|
||||||
Node* input = nullptr;
|
Node* input = nullptr;
|
||||||
|
|
|
@ -25,6 +25,7 @@ limitations under the License.
|
||||||
|
|
||||||
#if defined(INTEL_MKL)
|
#if defined(INTEL_MKL)
|
||||||
|
|
||||||
|
#include "mkldnn.hpp"
|
||||||
#include "tensorflow/core/framework/op.h"
|
#include "tensorflow/core/framework/op.h"
|
||||||
#include "tensorflow/core/framework/op_kernel.h"
|
#include "tensorflow/core/framework/op_kernel.h"
|
||||||
#include "tensorflow/core/framework/register_types.h"
|
#include "tensorflow/core/framework/register_types.h"
|
||||||
|
@ -32,13 +33,6 @@ limitations under the License.
|
||||||
#include "tensorflow/core/kernels/mkl_matmul_ops_common.h"
|
#include "tensorflow/core/kernels/mkl_matmul_ops_common.h"
|
||||||
#include "tensorflow/core/util/mkl_util.h"
|
#include "tensorflow/core/util/mkl_util.h"
|
||||||
|
|
||||||
// This header file is part of MKL ML; an equivalent header is needed for MKL DNN.
|
|
||||||
#ifndef INTEL_MKL_DNN_ONLY
|
|
||||||
#include "mkl_cblas.h"
|
|
||||||
#endif
|
|
||||||
|
|
||||||
#include "mkldnn.h"
|
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
||||||
typedef Eigen::ThreadPoolDevice CPUDevice;
|
typedef Eigen::ThreadPoolDevice CPUDevice;
|
||||||
|
@ -157,21 +151,11 @@ class MklMatMulOp : public OpKernel {
|
||||||
// 1.0 and 0.0 respectively.
|
// 1.0 and 0.0 respectively.
|
||||||
const float alpha = 1.0f;
|
const float alpha = 1.0f;
|
||||||
const float beta = 0.0f;
|
const float beta = 0.0f;
|
||||||
#if defined(INTEL_MKL_DNN_ONLY)
|
char char_transa = transa ? 'T' : 'N';
|
||||||
const char* const ftrans[] = {"N", "T", "C"};
|
char char_transb = transb ? 'T' : 'N';
|
||||||
int index_transa = transa ? 1 : 0;
|
VLOG(2) << "MKL DNN SGEMM CALLED";
|
||||||
int index_transb = transb ? 1 : 0;
|
dnnl_sgemm(char_transa, char_transb, m, n, k, alpha, a, lda, b, ldb, beta,
|
||||||
VLOG(2) << "MKL DNN SGEMM called";
|
c, ldc);
|
||||||
// MKL DNN only supports the Fortran api and requires column major while
|
|
||||||
// TensorFlow uses row major, so we reverse the order of A and B
|
|
||||||
mkldnn_sgemm(ftrans[index_transb], ftrans[index_transa], &n, &m, &k, &alpha,
|
|
||||||
b, &ldb, a, &lda, &beta, c, &ldc);
|
|
||||||
#else
|
|
||||||
// MKL ML binary uses CBLAS API
|
|
||||||
cblas_sgemm(CblasRowMajor, transa ? CblasTrans : CblasNoTrans,
|
|
||||||
transb ? CblasTrans : CblasNoTrans, m, n, k, alpha, a, lda, b,
|
|
||||||
ldb, beta, c, ldc);
|
|
||||||
#endif
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#ifdef ENABLE_INTEL_MKL_BFLOAT16
|
#ifdef ENABLE_INTEL_MKL_BFLOAT16
|
||||||
|
@ -205,53 +189,6 @@ class MklMatMulOp : public OpKernel {
|
||||||
FloatToBFloat16(c_float.flat<float>().data(), c, c_float.NumElements());
|
FloatToBFloat16(c_float.flat<float>().data(), c, c_float.NumElements());
|
||||||
}
|
}
|
||||||
#endif // ENABLE_INTEL_MKL_BFLOAT16
|
#endif // ENABLE_INTEL_MKL_BFLOAT16
|
||||||
|
|
||||||
// MKL-DNN only supports SGEMM and bfloat16-GEMM.
|
|
||||||
#ifndef INTEL_MKL_DNN_ONLY
|
|
||||||
|
|
||||||
// Matrix-Matrix Multiplication with FP64 tensors. For detailed info about
|
|
||||||
// parameters, look at FP32 function description.
|
|
||||||
void MklBlasGemm(OpKernelContext* ctx, bool transa, bool transb, const int m,
|
|
||||||
const int n, const int k, const double* a, const int lda,
|
|
||||||
const double* b, const int ldb, double* c, const int ldc) {
|
|
||||||
const double alpha = 1.0;
|
|
||||||
const double beta = 0.0;
|
|
||||||
cblas_dgemm(CblasRowMajor, transa ? CblasTrans : CblasNoTrans,
|
|
||||||
transb ? CblasTrans : CblasNoTrans, m, n, k, alpha, a, lda, b,
|
|
||||||
ldb, beta, c, ldc);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Matrix-Matrix Multiplication with Complex64 (std::complex<float>) tensors.
|
|
||||||
// For detailed info about parameters, look at FP32 function description.
|
|
||||||
void MklBlasGemm(OpKernelContext* ctx, bool transa, bool transb, const int m,
|
|
||||||
const int n, const int k, const complex64* a, const int lda,
|
|
||||||
const complex64* b, const int ldb, complex64* c,
|
|
||||||
int const ldc) {
|
|
||||||
const MKL_Complex8 alpha = {1.0f, 0.0f};
|
|
||||||
const MKL_Complex8 beta = {0.0f, 0.0f};
|
|
||||||
cblas_cgemm(CblasRowMajor, transa ? CblasTrans : CblasNoTrans,
|
|
||||||
transb ? CblasTrans : CblasNoTrans, m, n, k, &alpha,
|
|
||||||
reinterpret_cast<const MKL_Complex8*>(a), lda,
|
|
||||||
reinterpret_cast<const MKL_Complex8*>(b), ldb, &beta,
|
|
||||||
reinterpret_cast<MKL_Complex8*>(c), ldc);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Matrix-Matrix Multiplication with Complex128 (std::complex<double>)
|
|
||||||
// tensors. For detailed info about parameters, look at FP32 function
|
|
||||||
// description.
|
|
||||||
void MklBlasGemm(OpKernelContext* ctx, bool transa, bool transb, const int m,
|
|
||||||
const int n, const int k, const complex128* a, const int lda,
|
|
||||||
const complex128* b, const int ldb, complex128* c,
|
|
||||||
const int ldc) {
|
|
||||||
const MKL_Complex16 alpha = {1.0, 0.0};
|
|
||||||
const MKL_Complex16 beta = {0.0, 0.0};
|
|
||||||
cblas_zgemm(CblasRowMajor, transa ? CblasTrans : CblasNoTrans,
|
|
||||||
transb ? CblasTrans : CblasNoTrans, m, n, k, &alpha,
|
|
||||||
reinterpret_cast<const MKL_Complex16*>(a), lda,
|
|
||||||
reinterpret_cast<const MKL_Complex16*>(b), ldb, &beta,
|
|
||||||
reinterpret_cast<MKL_Complex16*>(c), ldc);
|
|
||||||
}
|
|
||||||
#endif // !INTEL_MKL_DNN_ONLY
|
|
||||||
};
|
};
|
||||||
|
|
||||||
#define REGISTER_CPU(T) \
|
#define REGISTER_CPU(T) \
|
||||||
|
@ -269,13 +206,6 @@ TF_CALL_float(REGISTER_CPU);
|
||||||
#ifdef ENABLE_INTEL_MKL_BFLOAT16
|
#ifdef ENABLE_INTEL_MKL_BFLOAT16
|
||||||
TF_CALL_bfloat16(REGISTER_CPU);
|
TF_CALL_bfloat16(REGISTER_CPU);
|
||||||
#endif // ENABLE_INTEL_MKL_BFLOAT16
|
#endif // ENABLE_INTEL_MKL_BFLOAT16
|
||||||
|
|
||||||
#ifndef INTEL_MKL_DNN_ONLY
|
|
||||||
TF_CALL_double(REGISTER_CPU);
|
|
||||||
TF_CALL_complex64(REGISTER_CPU);
|
|
||||||
TF_CALL_complex128(REGISTER_CPU);
|
|
||||||
#endif // !INTEL_MKL_DNN_ONLY
|
|
||||||
#endif // ENABLE_MKL
|
#endif // ENABLE_MKL
|
||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
#endif // INTEL_MKL
|
#endif // INTEL_MKL
|
||||||
|
|
|
@ -936,7 +936,7 @@ REGISTER_OP("_MklMatMul")
|
||||||
.Output("product: T")
|
.Output("product: T")
|
||||||
.Attr("transpose_a: bool = false")
|
.Attr("transpose_a: bool = false")
|
||||||
.Attr("transpose_b: bool = false")
|
.Attr("transpose_b: bool = false")
|
||||||
.Attr("T: {bfloat16, float, double, complex64, complex128}")
|
.Attr("T: {bfloat16, float}")
|
||||||
.SetShapeFn(shape_inference::MatMulShape);
|
.SetShapeFn(shape_inference::MatMulShape);
|
||||||
#endif // INTEL_MKL
|
#endif // INTEL_MKL
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue