Fix sgemm: call dnnl_sgemm when ENABLE_MKLDNN_V1 is defined, otherwise keep the cblas_sgemm path.
This commit is contained in:
parent
be46769cee
commit
19e210f87c
@ -25,7 +25,11 @@ limitations under the License.
|
||||
|
||||
#if defined(INTEL_MKL)
|
||||
|
||||
#ifdef ENABLE_MKLDNN_V1
|
||||
#include "mkldnn.hpp"
|
||||
#else
|
||||
#include "mkl_cblas.h"
|
||||
#endif // ENABLE_MKLDNN_V1
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/register_types.h"
|
||||
@ -58,11 +62,11 @@ class MklMatMulOp : public OpKernel {
|
||||
dim_pair[0].first = transpose_a_ ? 0 : 1;
|
||||
dim_pair[0].second = transpose_b_ ? 1 : 0;
|
||||
|
||||
OP_REQUIRES(
|
||||
ctx, a.dim_size(dim_pair[0].first) == b.dim_size(dim_pair[0].second),
|
||||
errors::InvalidArgument(
|
||||
"Matrix size-incompatible: In[0]: ", a.shape().DebugString(),
|
||||
", In[1]: ", b.shape().DebugString()));
|
||||
OP_REQUIRES(ctx,
|
||||
a.dim_size(dim_pair[0].first) == b.dim_size(dim_pair[0].second),
|
||||
errors::InvalidArgument("Matrix size-incompatible: In[0]: ",
|
||||
a.shape().DebugString(), ", In[1]: ",
|
||||
b.shape().DebugString()));
|
||||
int a_dim_remaining = 1 - dim_pair[0].first;
|
||||
int b_dim_remaining = 1 - dim_pair[0].second;
|
||||
TensorShape out_shape(
|
||||
@ -151,11 +155,18 @@ class MklMatMulOp : public OpKernel {
|
||||
// 1.0 and 0.0 respectively.
|
||||
const float alpha = 1.0f;
|
||||
const float beta = 0.0f;
|
||||
#ifdef ENABLE_MKLDNN_V1
|
||||
char char_transa = transa ? 'T' : 'N';
|
||||
char char_transb = transb ? 'T' : 'N';
|
||||
VLOG(2) << "MKL DNN SGEMM CALLED";
|
||||
dnnl_sgemm(char_transa, char_transb, m, n, k, alpha, a, lda, b, ldb, beta,
|
||||
c, ldc);
|
||||
#else
|
||||
// TODO(intel-tf): Remove this cblas_sgemm fallback after the TF2.3 fork.
|
||||
cblas_sgemm(CblasRowMajor, transa ? CblasTrans : CblasNoTrans,
|
||||
transb ? CblasTrans : CblasNoTrans, m, n, k, alpha, a, lda, b,
|
||||
ldb, beta, c, ldc);
|
||||
#endif // ENABLE_MKLDNN_V1
|
||||
}
|
||||
|
||||
#ifdef ENABLE_INTEL_MKL_BFLOAT16
|
||||
|
Loading…
x
Reference in New Issue
Block a user