Merge pull request #39915 from Intel-tensorflow:sriniva2/sgemm_fix
PiperOrigin-RevId: 313831498 Change-Id: I0b68134e7c0389506cb97259eac79a4301a3e9d3
This commit is contained in:
commit
8c8bd07458
@ -25,7 +25,11 @@ limitations under the License.
|
|||||||
|
|
||||||
#if defined(INTEL_MKL)
|
#if defined(INTEL_MKL)
|
||||||
|
|
||||||
|
#ifdef ENABLE_MKLDNN_V1
|
||||||
#include "mkldnn.hpp"
|
#include "mkldnn.hpp"
|
||||||
|
#else
|
||||||
|
#include "mkl_cblas.h"
|
||||||
|
#endif // ENABLE_MKLDNN_V1
|
||||||
#include "tensorflow/core/framework/op.h"
|
#include "tensorflow/core/framework/op.h"
|
||||||
#include "tensorflow/core/framework/op_kernel.h"
|
#include "tensorflow/core/framework/op_kernel.h"
|
||||||
#include "tensorflow/core/framework/register_types.h"
|
#include "tensorflow/core/framework/register_types.h"
|
||||||
@ -151,11 +155,18 @@ class MklMatMulOp : public OpKernel {
|
|||||||
// 1.0 and 0.0 respectively.
|
// 1.0 and 0.0 respectively.
|
||||||
const float alpha = 1.0f;
|
const float alpha = 1.0f;
|
||||||
const float beta = 0.0f;
|
const float beta = 0.0f;
|
||||||
|
#ifdef ENABLE_MKLDNN_V1
|
||||||
char char_transa = transa ? 'T' : 'N';
|
char char_transa = transa ? 'T' : 'N';
|
||||||
char char_transb = transb ? 'T' : 'N';
|
char char_transb = transb ? 'T' : 'N';
|
||||||
VLOG(2) << "MKL DNN SGEMM CALLED";
|
VLOG(2) << "MKL DNN SGEMM CALLED";
|
||||||
dnnl_sgemm(char_transa, char_transb, m, n, k, alpha, a, lda, b, ldb, beta,
|
dnnl_sgemm(char_transa, char_transb, m, n, k, alpha, a, lda, b, ldb, beta,
|
||||||
c, ldc);
|
c, ldc);
|
||||||
|
#else
|
||||||
|
// TODO(intel-tf): Remove this after TF2.3 fork.
|
||||||
|
cblas_sgemm(CblasRowMajor, transa ? CblasTrans : CblasNoTrans,
|
||||||
|
transb ? CblasTrans : CblasNoTrans, m, n, k, alpha, a, lda, b,
|
||||||
|
ldb, beta, c, ldc);
|
||||||
|
#endif // ENABLE_MKLDNN_V1
|
||||||
}
|
}
|
||||||
|
|
||||||
#ifdef ENABLE_INTEL_MKL_BFLOAT16
|
#ifdef ENABLE_INTEL_MKL_BFLOAT16
|
||||||
|
Loading…
x
Reference in New Issue
Block a user