Merge pull request #39325 from Intel-tensorflow:sriniva2/dnnl_sgemm
PiperOrigin-RevId: 312846725 Change-Id: I8204aea5531380ab1f916c4053f4049d01a788d5
This commit is contained in:
commit
6333fef206
|
@ -506,7 +506,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
|
|||
CopyAttrsAll, LrnGradRewrite, kRewriteForLayoutPropagation});
|
||||
rinfo_.push_back({csinfo_.matmul,
|
||||
mkl_op_registry::GetMklOpName(csinfo_.matmul),
|
||||
CopyAttrsAll, AlwaysRewrite, kRewriteForOpNameChange});
|
||||
CopyAttrsAll, MatMulRewrite, kRewriteForOpNameChange});
|
||||
rinfo_.push_back(
|
||||
{csinfo_.leakyrelu, mkl_op_registry::GetMklOpName(csinfo_.leakyrelu),
|
||||
CopyAttrsAll, LeakyReluRewrite, kRewriteForLayoutPropagation});
|
||||
|
@ -1482,6 +1482,16 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
|
|||
return false;
|
||||
}
|
||||
|
||||
static bool MatMulRewrite(const Node* n) {
|
||||
DataType T;
|
||||
GetNodeAttr(n->def(), "T", &T);
|
||||
if ((T == DT_FLOAT) || (T == DT_BFLOAT16)) {
|
||||
VLOG(2) << "Rewriting MatMul to _MklMatMul";
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
static bool DequantizeRewrite(const Node* n) {
|
||||
DCHECK(n);
|
||||
Node* input = nullptr;
|
||||
|
|
|
@ -25,6 +25,7 @@ limitations under the License.
|
|||
|
||||
#if defined(INTEL_MKL)
|
||||
|
||||
#include "mkldnn.hpp"
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/register_types.h"
|
||||
|
@ -32,13 +33,6 @@ limitations under the License.
|
|||
#include "tensorflow/core/kernels/mkl_matmul_ops_common.h"
|
||||
#include "tensorflow/core/util/mkl_util.h"
|
||||
|
||||
// This header file is part of MKL ML, need equivalent file in MKL DNN
|
||||
#ifndef INTEL_MKL_DNN_ONLY
|
||||
#include "mkl_cblas.h"
|
||||
#endif
|
||||
|
||||
#include "mkldnn.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
typedef Eigen::ThreadPoolDevice CPUDevice;
|
||||
|
@ -157,21 +151,11 @@ class MklMatMulOp : public OpKernel {
|
|||
// 1.0 and 0.0 respectively.
|
||||
const float alpha = 1.0f;
|
||||
const float beta = 0.0f;
|
||||
#if defined(INTEL_MKL_DNN_ONLY)
|
||||
const char* const ftrans[] = {"N", "T", "C"};
|
||||
int index_transa = transa ? 1 : 0;
|
||||
int index_transb = transb ? 1 : 0;
|
||||
VLOG(2) << "MKL DNN SGEMM called";
|
||||
// MKL DNN only supports the Fortran api and requires column major while
|
||||
// Tensorflow uses row major so we reverse the order A and B
|
||||
mkldnn_sgemm(ftrans[index_transb], ftrans[index_transa], &n, &m, &k, &alpha,
|
||||
b, &ldb, a, &lda, &beta, c, &ldc);
|
||||
#else
|
||||
// MKL ML binary uses CBLAS API
|
||||
cblas_sgemm(CblasRowMajor, transa ? CblasTrans : CblasNoTrans,
|
||||
transb ? CblasTrans : CblasNoTrans, m, n, k, alpha, a, lda, b,
|
||||
ldb, beta, c, ldc);
|
||||
#endif
|
||||
char char_transa = transa ? 'T' : 'N';
|
||||
char char_transb = transb ? 'T' : 'N';
|
||||
VLOG(2) << "MKL DNN SGEMM CALLED";
|
||||
dnnl_sgemm(char_transa, char_transb, m, n, k, alpha, a, lda, b, ldb, beta,
|
||||
c, ldc);
|
||||
}
|
||||
|
||||
#ifdef ENABLE_INTEL_MKL_BFLOAT16
|
||||
|
@ -205,53 +189,6 @@ class MklMatMulOp : public OpKernel {
|
|||
FloatToBFloat16(c_float.flat<float>().data(), c, c_float.NumElements());
|
||||
}
|
||||
#endif // ENABLE_INTEL_MKL_BFLOAT16
|
||||
|
||||
// MKL-DNN only supports SGEMM and bfloat16-GEMM.
|
||||
#ifndef INTEL_MKL_DNN_ONLY
|
||||
|
||||
// Matrix-Matrix Multiplication with FP64 tensors. For detailed info about
|
||||
// parameters, look at FP32 function description.
|
||||
void MklBlasGemm(OpKernelContext* ctx, bool transa, bool transb, const int m,
|
||||
const int n, const int k, const double* a, const int lda,
|
||||
const double* b, const int ldb, double* c, const int ldc) {
|
||||
const double alpha = 1.0;
|
||||
const double beta = 0.0;
|
||||
cblas_dgemm(CblasRowMajor, transa ? CblasTrans : CblasNoTrans,
|
||||
transb ? CblasTrans : CblasNoTrans, m, n, k, alpha, a, lda, b,
|
||||
ldb, beta, c, ldc);
|
||||
}
|
||||
|
||||
// Matrix-Matrix Multiplication with Complex64 (std::complex<float>) tensors.
|
||||
// For detailed info about parameters, look at FP32 function description.
|
||||
void MklBlasGemm(OpKernelContext* ctx, bool transa, bool transb, const int m,
|
||||
const int n, const int k, const complex64* a, const int lda,
|
||||
const complex64* b, const int ldb, complex64* c,
|
||||
int const ldc) {
|
||||
const MKL_Complex8 alpha = {1.0f, 0.0f};
|
||||
const MKL_Complex8 beta = {0.0f, 0.0f};
|
||||
cblas_cgemm(CblasRowMajor, transa ? CblasTrans : CblasNoTrans,
|
||||
transb ? CblasTrans : CblasNoTrans, m, n, k, &alpha,
|
||||
reinterpret_cast<const MKL_Complex8*>(a), lda,
|
||||
reinterpret_cast<const MKL_Complex8*>(b), ldb, &beta,
|
||||
reinterpret_cast<MKL_Complex8*>(c), ldc);
|
||||
}
|
||||
|
||||
// Matrix-Matrix Multiplication with Complex128 (std::complex<double>)
|
||||
// tensors. For detailed info about parameters, look at FP32 function
|
||||
// description.
|
||||
void MklBlasGemm(OpKernelContext* ctx, bool transa, bool transb, const int m,
|
||||
const int n, const int k, const complex128* a, const int lda,
|
||||
const complex128* b, const int ldb, complex128* c,
|
||||
const int ldc) {
|
||||
const MKL_Complex16 alpha = {1.0, 0.0};
|
||||
const MKL_Complex16 beta = {0.0, 0.0};
|
||||
cblas_zgemm(CblasRowMajor, transa ? CblasTrans : CblasNoTrans,
|
||||
transb ? CblasTrans : CblasNoTrans, m, n, k, &alpha,
|
||||
reinterpret_cast<const MKL_Complex16*>(a), lda,
|
||||
reinterpret_cast<const MKL_Complex16*>(b), ldb, &beta,
|
||||
reinterpret_cast<MKL_Complex16*>(c), ldc);
|
||||
}
|
||||
#endif // !INTEL_MKL_DNN_ONLY
|
||||
};
|
||||
|
||||
#define REGISTER_CPU(T) \
|
||||
|
@ -269,13 +206,6 @@ TF_CALL_float(REGISTER_CPU);
|
|||
#ifdef ENABLE_INTEL_MKL_BFLOAT16
|
||||
TF_CALL_bfloat16(REGISTER_CPU);
|
||||
#endif // ENABLE_INTEL_MKL_BFLOAT16
|
||||
|
||||
#ifndef INTEL_MKL_DNN_ONLY
|
||||
TF_CALL_double(REGISTER_CPU);
|
||||
TF_CALL_complex64(REGISTER_CPU);
|
||||
TF_CALL_complex128(REGISTER_CPU);
|
||||
#endif // !INTEL_MKL_DNN_ONLY
|
||||
#endif // ENABLE_MKL
|
||||
|
||||
} // namespace tensorflow
|
||||
#endif // INTEL_MKL
|
||||
|
|
|
@ -936,7 +936,7 @@ REGISTER_OP("_MklMatMul")
|
|||
.Output("product: T")
|
||||
.Attr("transpose_a: bool = false")
|
||||
.Attr("transpose_b: bool = false")
|
||||
.Attr("T: {bfloat16, float, double, complex64, complex128}")
|
||||
.Attr("T: {bfloat16, float}")
|
||||
.SetShapeFn(shape_inference::MatMulShape);
|
||||
#endif // INTEL_MKL
|
||||
|
||||
|
|
Loading…
Reference in New Issue