Fix sgemm: call dnnl_sgemm when ENABLE_MKLDNN_V1 is defined, otherwise fall back to cblas_sgemm.

This commit is contained in:
Srinivasan Narayanamoorthy 2020-05-27 10:53:42 -07:00
parent be46769cee
commit 19e210f87c

View File

@ -25,7 +25,11 @@ limitations under the License.
#if defined(INTEL_MKL)
#ifdef ENABLE_MKLDNN_V1
#include "mkldnn.hpp"
#else
#include "mkl_cblas.h"
#endif // ENABLE_MKLDNN_V1
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
@ -58,11 +62,11 @@ class MklMatMulOp : public OpKernel {
dim_pair[0].first = transpose_a_ ? 0 : 1;
dim_pair[0].second = transpose_b_ ? 1 : 0;
OP_REQUIRES(
ctx, a.dim_size(dim_pair[0].first) == b.dim_size(dim_pair[0].second),
errors::InvalidArgument(
"Matrix size-incompatible: In[0]: ", a.shape().DebugString(),
", In[1]: ", b.shape().DebugString()));
OP_REQUIRES(ctx,
a.dim_size(dim_pair[0].first) == b.dim_size(dim_pair[0].second),
errors::InvalidArgument("Matrix size-incompatible: In[0]: ",
a.shape().DebugString(), ", In[1]: ",
b.shape().DebugString()));
int a_dim_remaining = 1 - dim_pair[0].first;
int b_dim_remaining = 1 - dim_pair[0].second;
TensorShape out_shape(
@ -151,11 +155,18 @@ class MklMatMulOp : public OpKernel {
// 1.0 and 0.0 respectively.
const float alpha = 1.0f;
const float beta = 0.0f;
#ifdef ENABLE_MKLDNN_V1
char char_transa = transa ? 'T' : 'N';
char char_transb = transb ? 'T' : 'N';
VLOG(2) << "MKL DNN SGEMM CALLED";
dnnl_sgemm(char_transa, char_transb, m, n, k, alpha, a, lda, b, ldb, beta,
c, ldc);
#else
// TODO(intel-tf): Remove this cblas_sgemm fallback after the TF 2.3 fork.
cblas_sgemm(CblasRowMajor, transa ? CblasTrans : CblasNoTrans,
transb ? CblasTrans : CblasNoTrans, m, n, k, alpha, a, lda, b,
ldb, beta, c, ldc);
#endif // ENABLE_MKLDNN_V1
}
#ifdef ENABLE_INTEL_MKL_BFLOAT16