diff --git a/tensorflow/core/kernels/mkl_matmul_op.cc b/tensorflow/core/kernels/mkl_matmul_op.cc index 687f67f6283..86193901c96 100644 --- a/tensorflow/core/kernels/mkl_matmul_op.cc +++ b/tensorflow/core/kernels/mkl_matmul_op.cc @@ -25,7 +25,11 @@ limitations under the License. #if defined(INTEL_MKL) +#ifdef ENABLE_MKLDNN_V1 #include "mkldnn.hpp" +#else +#include "mkl_cblas.h" +#endif // ENABLE_MKLDNN_V1 #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" @@ -151,11 +155,18 @@ class MklMatMulOp : public OpKernel { // 1.0 and 0.0 respectively. const float alpha = 1.0f; const float beta = 0.0f; +#ifdef ENABLE_MKLDNN_V1 char char_transa = transa ? 'T' : 'N'; char char_transb = transb ? 'T' : 'N'; VLOG(2) << "MKL DNN SGEMM CALLED"; dnnl_sgemm(char_transa, char_transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); +#else + // TODO(intel-tf): Remove this after TF2.3 fork. + cblas_sgemm(CblasRowMajor, transa ? CblasTrans : CblasNoTrans, + transb ? CblasTrans : CblasNoTrans, m, n, k, alpha, a, lda, b, + ldb, beta, c, ldc); +#endif // ENABLE_MKLDNN_V1 } #ifdef ENABLE_INTEL_MKL_BFLOAT16