Enable DNNL SGEMM and remove all MKL-ML (cblas) MatMul code paths.

This commit is contained in:
Srinivasan Narayanamoorthy 2020-05-08 13:10:15 -07:00
parent a8b9d64276
commit 738a28685b
3 changed files with 18 additions and 78 deletions

View File

@ -499,7 +499,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
CopyAttrsAll, LrnGradRewrite, kRewriteForLayoutPropagation});
rinfo_.push_back({csinfo_.matmul,
mkl_op_registry::GetMklOpName(csinfo_.matmul),
CopyAttrsAll, AlwaysRewrite, kRewriteForOpNameChange});
CopyAttrsAll, MatMulRewrite, kRewriteForOpNameChange});
rinfo_.push_back(
{csinfo_.leakyrelu, mkl_op_registry::GetMklOpName(csinfo_.leakyrelu),
CopyAttrsAll, LeakyReluRewrite, kRewriteForLayoutPropagation});
@ -1473,6 +1473,16 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
return false;
}
static bool MatMulRewrite(const Node* n) {
  // Returns true iff the MatMul node `n` should be rewritten to _MklMatMul.
  // DNNL provides only SGEMM (float) and a bfloat16 GEMM, so the rewrite is
  // limited to those two element types (matching the _MklMatMul op's
  // "T: {bfloat16, float}" attr constraint).
  DataType T;
  // GetNodeAttr returns a Status; the original code ignored it, so a missing
  // "T" attr would leave T uninitialized and be read. Treat a missing or
  // malformed attr as "do not rewrite".
  if (!GetNodeAttr(n->def(), "T", &T).ok()) {
    return false;
  }
  if ((T == DT_FLOAT) || (T == DT_BFLOAT16)) {
    VLOG(2) << "Rewriting MatMul to _MklMatMul";
    return true;
  }
  return false;
}
static bool DequantizeRewrite(const Node* n) {
DCHECK(n);
Node* input = nullptr;

View File

@ -31,13 +31,7 @@ limitations under the License.
#include "tensorflow/core/kernels/fill_functor.h"
#include "tensorflow/core/kernels/mkl_matmul_ops_common.h"
#include "tensorflow/core/util/mkl_util.h"
// This header file is part of MKL ML; an equivalent file is needed in MKL DNN.
#ifndef INTEL_MKL_DNN_ONLY
#include "mkl_cblas.h"
#endif
#include "mkldnn.h"
#include "mkldnn.hpp"
namespace tensorflow {
@ -157,21 +151,11 @@ class MklMatMulOp : public OpKernel {
// 1.0 and 0.0 respectively.
const float alpha = 1.0f;
const float beta = 0.0f;
#if defined(INTEL_MKL_DNN_ONLY)
const char* const ftrans[] = {"N", "T", "C"};
int index_transa = transa ? 1 : 0;
int index_transb = transb ? 1 : 0;
VLOG(2) << "MKL DNN SGEMM called";
// MKL DNN only supports the Fortran API and requires column-major layout,
// while TensorFlow uses row-major, so we swap the order of A and B.
mkldnn_sgemm(ftrans[index_transb], ftrans[index_transa], &n, &m, &k, &alpha,
b, &ldb, a, &lda, &beta, c, &ldc);
#else
// MKL ML binary uses CBLAS API
cblas_sgemm(CblasRowMajor, transa ? CblasTrans : CblasNoTrans,
transb ? CblasTrans : CblasNoTrans, m, n, k, alpha, a, lda, b,
ldb, beta, c, ldc);
#endif
char char_transa = transa ? 'T' : 'N';
char char_transb = transb ? 'T' : 'N';
VLOG(2) << "MKL DNN SGEMM CALLED";
dnnl_sgemm(char_transa, char_transb, m, n, k, alpha,
a, lda, b, ldb, beta, c, ldc);
}
#ifdef ENABLE_INTEL_MKL_BFLOAT16
@ -205,53 +189,6 @@ class MklMatMulOp : public OpKernel {
FloatToBFloat16(c_float.flat<float>().data(), c, c_float.NumElements());
}
#endif // ENABLE_INTEL_MKL_BFLOAT16
// MKL-DNN only supports SGEMM and bfloat16-GEMM.
#ifndef INTEL_MKL_DNN_ONLY
// Matrix-matrix multiplication with FP64 tensors. See the FP32 overload
// above for a description of the parameters.
void MklBlasGemm(OpKernelContext* ctx, bool transa, bool transb, const int m,
                 const int n, const int k, const double* a, const int lda,
                 const double* b, const int ldb, double* c, const int ldc) {
  // Plain product: C = 1.0 * op(A) * op(B) + 0.0 * C.
  const double alpha = 1.0;
  const double beta = 0.0;
  const auto op_a = transa ? CblasTrans : CblasNoTrans;
  const auto op_b = transb ? CblasTrans : CblasNoTrans;
  cblas_dgemm(CblasRowMajor, op_a, op_b, m, n, k, alpha, a, lda, b, ldb, beta,
              c, ldc);
}
// Matrix-matrix multiplication with Complex64 (std::complex<float>) tensors.
// See the FP32 overload above for a description of the parameters.
void MklBlasGemm(OpKernelContext* ctx, bool transa, bool transb, const int m,
                 const int n, const int k, const complex64* a, const int lda,
                 const complex64* b, const int ldb, complex64* c,
                 int const ldc) {
  // Plain product: alpha = 1 + 0i, beta = 0 + 0i.
  const MKL_Complex8 alpha = {1.0f, 0.0f};
  const MKL_Complex8 beta = {0.0f, 0.0f};
  const auto op_a = transa ? CblasTrans : CblasNoTrans;
  const auto op_b = transb ? CblasTrans : CblasNoTrans;
  // tensorflow::complex64 and MKL_Complex8 share layout; MKL's CBLAS entry
  // points take MKL_Complex8*, hence the reinterpret_casts.
  cblas_cgemm(CblasRowMajor, op_a, op_b, m, n, k, &alpha,
              reinterpret_cast<const MKL_Complex8*>(a), lda,
              reinterpret_cast<const MKL_Complex8*>(b), ldb, &beta,
              reinterpret_cast<MKL_Complex8*>(c), ldc);
}
// Matrix-matrix multiplication with Complex128 (std::complex<double>)
// tensors. See the FP32 overload above for a description of the parameters.
void MklBlasGemm(OpKernelContext* ctx, bool transa, bool transb, const int m,
                 const int n, const int k, const complex128* a, const int lda,
                 const complex128* b, const int ldb, complex128* c,
                 const int ldc) {
  // Plain product: alpha = 1 + 0i, beta = 0 + 0i.
  const MKL_Complex16 alpha = {1.0, 0.0};
  const MKL_Complex16 beta = {0.0, 0.0};
  const auto op_a = transa ? CblasTrans : CblasNoTrans;
  const auto op_b = transb ? CblasTrans : CblasNoTrans;
  // tensorflow::complex128 and MKL_Complex16 share layout; MKL's CBLAS entry
  // points take MKL_Complex16*, hence the reinterpret_casts.
  cblas_zgemm(CblasRowMajor, op_a, op_b, m, n, k, &alpha,
              reinterpret_cast<const MKL_Complex16*>(a), lda,
              reinterpret_cast<const MKL_Complex16*>(b), ldb, &beta,
              reinterpret_cast<MKL_Complex16*>(c), ldc);
}
#endif // !INTEL_MKL_DNN_ONLY
};
#define REGISTER_CPU(T) \
@ -269,13 +206,6 @@ TF_CALL_float(REGISTER_CPU);
#ifdef ENABLE_INTEL_MKL_BFLOAT16
TF_CALL_bfloat16(REGISTER_CPU);
#endif // ENABLE_INTEL_MKL_BFLOAT16
#ifndef INTEL_MKL_DNN_ONLY
TF_CALL_double(REGISTER_CPU);
TF_CALL_complex64(REGISTER_CPU);
TF_CALL_complex128(REGISTER_CPU);
#endif // !INTEL_MKL_DNN_ONLY
#endif // ENABLE_MKL
} // namespace tensorflow
#endif // INTEL_MKL

View File

@ -936,7 +936,7 @@ REGISTER_OP("_MklMatMul")
.Output("product: T")
.Attr("transpose_a: bool = false")
.Attr("transpose_b: bool = false")
.Attr("T: {bfloat16, float, double, complex64, complex128}")
.Attr("T: {bfloat16, float}")
.SetShapeFn(shape_inference::MatMulShape);
#endif // INTEL_MKL