Merge pull request #39915 from Intel-tensorflow:sriniva2/sgemm_fix
PiperOrigin-RevId: 313831498 Change-Id: I0b68134e7c0389506cb97259eac79a4301a3e9d3
This commit is contained in:
commit
8c8bd07458
@ -25,7 +25,11 @@ limitations under the License.
|
||||
|
||||
#if defined(INTEL_MKL)
|
||||
|
||||
#ifdef ENABLE_MKLDNN_V1
|
||||
#include "mkldnn.hpp"
|
||||
#else
|
||||
#include "mkl_cblas.h"
|
||||
#endif // ENABLE_MKLDNN_V1
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/register_types.h"
|
||||
@ -151,11 +155,18 @@ class MklMatMulOp : public OpKernel {
|
||||
// 1.0 and 0.0 respectively.
|
||||
const float alpha = 1.0f;
|
||||
const float beta = 0.0f;
|
||||
#ifdef ENABLE_MKLDNN_V1
|
||||
char char_transa = transa ? 'T' : 'N';
|
||||
char char_transb = transb ? 'T' : 'N';
|
||||
VLOG(2) << "MKL DNN SGEMM CALLED";
|
||||
dnnl_sgemm(char_transa, char_transb, m, n, k, alpha, a, lda, b, ldb, beta,
|
||||
c, ldc);
|
||||
#else
|
||||
// TODO(intel-tf): Remove this after TF2.3 fork.
|
||||
cblas_sgemm(CblasRowMajor, transa ? CblasTrans : CblasNoTrans,
|
||||
transb ? CblasTrans : CblasNoTrans, m, n, k, alpha, a, lda, b,
|
||||
ldb, beta, c, ldc);
|
||||
#endif // ENABLE_MKLDNN_V1
|
||||
}
|
||||
|
||||
#ifdef ENABLE_INTEL_MKL_BFLOAT16
|
||||
|
Loading…
x
Reference in New Issue
Block a user