Merge pull request #39915 from Intel-tensorflow:sriniva2/sgemm_fix

PiperOrigin-RevId: 313831498
Change-Id: I0b68134e7c0389506cb97259eac79a4301a3e9d3
This commit is contained in:
TensorFlower Gardener 2020-05-29 13:57:42 -07:00
commit 8c8bd07458

View File

@ -25,7 +25,11 @@ limitations under the License.
#if defined(INTEL_MKL)
#ifdef ENABLE_MKLDNN_V1
#include "mkldnn.hpp"
#else
#include "mkl_cblas.h"
#endif // ENABLE_MKLDNN_V1
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
@ -151,11 +155,18 @@ class MklMatMulOp : public OpKernel {
// 1.0 and 0.0 respectively.
const float alpha = 1.0f;
const float beta = 0.0f;
#ifdef ENABLE_MKLDNN_V1
char char_transa = transa ? 'T' : 'N';
char char_transb = transb ? 'T' : 'N';
VLOG(2) << "MKL DNN SGEMM CALLED";
dnnl_sgemm(char_transa, char_transb, m, n, k, alpha, a, lda, b, ldb, beta,
c, ldc);
#else
// TODO(intel-tf): Remove this after TF2.3 fork.
cblas_sgemm(CblasRowMajor, transa ? CblasTrans : CblasNoTrans,
transb ? CblasTrans : CblasNoTrans, m, n, k, alpha, a, lda, b,
ldb, beta, c, ldc);
#endif // ENABLE_MKLDNN_V1
}
#ifdef ENABLE_INTEL_MKL_BFLOAT16