diff --git a/tensorflow/core/common_runtime/mkl_layout_pass.cc b/tensorflow/core/common_runtime/mkl_layout_pass.cc index 2941845a604..55355363106 100644 --- a/tensorflow/core/common_runtime/mkl_layout_pass.cc +++ b/tensorflow/core/common_runtime/mkl_layout_pass.cc @@ -499,7 +499,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass { CopyAttrsAll, LrnGradRewrite, kRewriteForLayoutPropagation}); rinfo_.push_back({csinfo_.matmul, mkl_op_registry::GetMklOpName(csinfo_.matmul), - CopyAttrsAll, AlwaysRewrite, kRewriteForOpNameChange}); + CopyAttrsAll, MatMulRewrite, kRewriteForOpNameChange}); rinfo_.push_back( {csinfo_.leakyrelu, mkl_op_registry::GetMklOpName(csinfo_.leakyrelu), CopyAttrsAll, LeakyReluRewrite, kRewriteForLayoutPropagation}); @@ -1473,6 +1473,16 @@ class MklLayoutRewritePass : public GraphOptimizationPass { return false; } + static bool MatMulRewrite(const Node* n) { + DataType T; + GetNodeAttr(n->def(), "T", &T); + if ((T == DT_FLOAT) || (T == DT_BFLOAT16)) { + VLOG(2) << "Rewriting MatMul to _MklMatMul"; + return true; + } + return false; + } + static bool DequantizeRewrite(const Node* n) { DCHECK(n); Node* input = nullptr; diff --git a/tensorflow/core/kernels/mkl_matmul_op.cc b/tensorflow/core/kernels/mkl_matmul_op.cc index 3a7c864d10e..83785af8910 100644 --- a/tensorflow/core/kernels/mkl_matmul_op.cc +++ b/tensorflow/core/kernels/mkl_matmul_op.cc @@ -31,13 +31,7 @@ limitations under the License. #include "tensorflow/core/kernels/fill_functor.h" #include "tensorflow/core/kernels/mkl_matmul_ops_common.h" #include "tensorflow/core/util/mkl_util.h" - -// This header file is part of MKL ML, need equivalent file in MKL DNN -#ifndef INTEL_MKL_DNN_ONLY -#include "mkl_cblas.h" -#endif - -#include "mkldnn.h" +#include "mkldnn.hpp" namespace tensorflow { @@ -157,21 +151,11 @@ class MklMatMulOp : public OpKernel { // 1.0 and 0.0 respectively. const float alpha = 1.0f; const float beta = 0.0f; -#if defined(INTEL_MKL_DNN_ONLY) - const char* const ftrans[] = {"N", "T", "C"}; - int index_transa = transa ? 1 : 0; - int index_transb = transb ? 1 : 0; - VLOG(2) << "MKL DNN SGEMM called"; - // MKL DNN only supports the Fortran api and requires column major while - // Tensorflow uses row major so we reverse the order A and B - mkldnn_sgemm(ftrans[index_transb], ftrans[index_transa], &n, &m, &k, &alpha, - b, &ldb, a, &lda, &beta, c, &ldc); -#else - // MKL ML binary uses CBLAS API - cblas_sgemm(CblasRowMajor, transa ? CblasTrans : CblasNoTrans, - transb ? CblasTrans : CblasNoTrans, m, n, k, alpha, a, lda, b, - ldb, beta, c, ldc); -#endif + char char_transa = transa ? 'T' : 'N'; + char char_transb = transb ? 'T' : 'N'; + VLOG(2) << "MKL DNN SGEMM CALLED"; + dnnl_sgemm(char_transa, char_transb, m, n, k, alpha, + a, lda, b, ldb, beta, c, ldc); } #ifdef ENABLE_INTEL_MKL_BFLOAT16 @@ -205,53 +189,6 @@ class MklMatMulOp : public OpKernel { FloatToBFloat16(c_float.flat().data(), c, c_float.NumElements()); } #endif // ENABLE_INTEL_MKL_BFLOAT16 - -// MKL-DNN only supports SGEMM and bfloat16-GEMM. -#ifndef INTEL_MKL_DNN_ONLY - - // Matrix-Matrix Multiplication with FP64 tensors. For detailed info about - // parameters, look at FP32 function description. - void MklBlasGemm(OpKernelContext* ctx, bool transa, bool transb, const int m, - const int n, const int k, const double* a, const int lda, - const double* b, const int ldb, double* c, const int ldc) { - const double alpha = 1.0; - const double beta = 0.0; - cblas_dgemm(CblasRowMajor, transa ? CblasTrans : CblasNoTrans, - transb ? CblasTrans : CblasNoTrans, m, n, k, alpha, a, lda, b, - ldb, beta, c, ldc); - } - - // Matrix-Matrix Multiplication with Complex64 (std::complex) tensors. - // For detailed info about parameters, look at FP32 function description. - void MklBlasGemm(OpKernelContext* ctx, bool transa, bool transb, const int m, - const int n, const int k, const complex64* a, const int lda, - const complex64* b, const int ldb, complex64* c, - int const ldc) { - const MKL_Complex8 alpha = {1.0f, 0.0f}; - const MKL_Complex8 beta = {0.0f, 0.0f}; - cblas_cgemm(CblasRowMajor, transa ? CblasTrans : CblasNoTrans, - transb ? CblasTrans : CblasNoTrans, m, n, k, &alpha, - reinterpret_cast(a), lda, - reinterpret_cast(b), ldb, &beta, - reinterpret_cast(c), ldc); - } - - // Matrix-Matrix Multiplication with Complex128 (std::complex) - // tensors. For detailed info about parameters, look at FP32 function - // description. - void MklBlasGemm(OpKernelContext* ctx, bool transa, bool transb, const int m, - const int n, const int k, const complex128* a, const int lda, - const complex128* b, const int ldb, complex128* c, - const int ldc) { - const MKL_Complex16 alpha = {1.0, 0.0}; - const MKL_Complex16 beta = {0.0, 0.0}; - cblas_zgemm(CblasRowMajor, transa ? CblasTrans : CblasNoTrans, - transb ? CblasTrans : CblasNoTrans, m, n, k, &alpha, - reinterpret_cast(a), lda, - reinterpret_cast(b), ldb, &beta, - reinterpret_cast(c), ldc); - } -#endif // !INTEL_MKL_DNN_ONLY }; #define REGISTER_CPU(T) \ @@ -269,13 +206,6 @@ TF_CALL_float(REGISTER_CPU); #ifdef ENABLE_INTEL_MKL_BFLOAT16 TF_CALL_bfloat16(REGISTER_CPU); #endif // ENABLE_INTEL_MKL_BFLOAT16 - -#ifndef INTEL_MKL_DNN_ONLY -TF_CALL_double(REGISTER_CPU); -TF_CALL_complex64(REGISTER_CPU); -TF_CALL_complex128(REGISTER_CPU); -#endif // !INTEL_MKL_DNN_ONLY #endif // ENABLE_MKL - } // namespace tensorflow #endif // INTEL_MKL diff --git a/tensorflow/core/ops/math_ops.cc b/tensorflow/core/ops/math_ops.cc index 7ac003379d4..d00731f223a 100644 --- a/tensorflow/core/ops/math_ops.cc +++ b/tensorflow/core/ops/math_ops.cc @@ -936,7 +936,7 @@ REGISTER_OP("_MklMatMul") .Output("product: T") .Attr("transpose_a: bool = false") .Attr("transpose_b: bool = false") - .Attr("T: {bfloat16, float, double, complex64, complex128}") + .Attr("T: {bfloat16, float}") .SetShapeFn(shape_inference::MatMulShape); #endif // INTEL_MKL