From 19e210f87c7e28528dd063c77b565b598d31e1a1 Mon Sep 17 00:00:00 2001 From: Srinivasan Narayanamoorthy Date: Wed, 27 May 2020 10:53:42 -0700 Subject: [PATCH] sgemm fix. --- tensorflow/core/kernels/mkl_matmul_op.cc | 21 ++++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) diff --git a/tensorflow/core/kernels/mkl_matmul_op.cc b/tensorflow/core/kernels/mkl_matmul_op.cc index 3eccf97f53c..7e76e7fd6ca 100644 --- a/tensorflow/core/kernels/mkl_matmul_op.cc +++ b/tensorflow/core/kernels/mkl_matmul_op.cc @@ -25,7 +25,11 @@ limitations under the License. #if defined(INTEL_MKL) +#ifdef ENABLE_MKLDNN_V1 #include "mkldnn.hpp" +#else +#include "mkl_cblas.h" +#endif // ENABLE_MKLDNN_V1 #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" @@ -58,11 +62,11 @@ class MklMatMulOp : public OpKernel { dim_pair[0].first = transpose_a_ ? 0 : 1; dim_pair[0].second = transpose_b_ ? 1 : 0; - OP_REQUIRES( - ctx, a.dim_size(dim_pair[0].first) == b.dim_size(dim_pair[0].second), - errors::InvalidArgument( - "Matrix size-incompatible: In[0]: ", a.shape().DebugString(), - ", In[1]: ", b.shape().DebugString())); + OP_REQUIRES(ctx, + a.dim_size(dim_pair[0].first) == b.dim_size(dim_pair[0].second), + errors::InvalidArgument("Matrix size-incompatible: In[0]: ", + a.shape().DebugString(), ", In[1]: ", + b.shape().DebugString())); int a_dim_remaining = 1 - dim_pair[0].first; int b_dim_remaining = 1 - dim_pair[0].second; TensorShape out_shape( @@ -151,11 +155,18 @@ class MklMatMulOp : public OpKernel { // 1.0 and 0.0 respectively. const float alpha = 1.0f; const float beta = 0.0f; +#ifdef ENABLE_MKLDNN_V1 char char_transa = transa ? 'T' : 'N'; char char_transb = transb ? 'T' : 'N'; VLOG(2) << "MKL DNN SGEMM CALLED"; dnnl_sgemm(char_transa, char_transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); +#else + // TODO(intel-tf): Remove this after TF2.3 fork. + cblas_sgemm(CblasRowMajor, transa ? CblasTrans : CblasNoTrans, + transb ? CblasTrans : CblasNoTrans, m, n, k, alpha, a, lda, b, + ldb, beta, c, ldc); +#endif // ENABLE_MKLDNN_V1 } #ifdef ENABLE_INTEL_MKL_BFLOAT16