Merge pull request #39325 from Intel-tensorflow:sriniva2/dnnl_sgemm

PiperOrigin-RevId: 312846725
Change-Id: I8204aea5531380ab1f916c4053f4049d01a788d5
This commit is contained in:
TensorFlower Gardener 2020-05-22 12:04:12 -07:00
commit 6333fef206
3 changed files with 18 additions and 78 deletions

View File

@ -506,7 +506,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
CopyAttrsAll, LrnGradRewrite, kRewriteForLayoutPropagation}); CopyAttrsAll, LrnGradRewrite, kRewriteForLayoutPropagation});
rinfo_.push_back({csinfo_.matmul, rinfo_.push_back({csinfo_.matmul,
mkl_op_registry::GetMklOpName(csinfo_.matmul), mkl_op_registry::GetMklOpName(csinfo_.matmul),
CopyAttrsAll, AlwaysRewrite, kRewriteForOpNameChange}); CopyAttrsAll, MatMulRewrite, kRewriteForOpNameChange});
rinfo_.push_back( rinfo_.push_back(
{csinfo_.leakyrelu, mkl_op_registry::GetMklOpName(csinfo_.leakyrelu), {csinfo_.leakyrelu, mkl_op_registry::GetMklOpName(csinfo_.leakyrelu),
CopyAttrsAll, LeakyReluRewrite, kRewriteForLayoutPropagation}); CopyAttrsAll, LeakyReluRewrite, kRewriteForLayoutPropagation});
@ -1482,6 +1482,16 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
return false; return false;
} }
static bool MatMulRewrite(const Node* n) {
DataType T;
GetNodeAttr(n->def(), "T", &T);
if ((T == DT_FLOAT) || (T == DT_BFLOAT16)) {
VLOG(2) << "Rewriting MatMul to _MklMatMul";
return true;
}
return false;
}
static bool DequantizeRewrite(const Node* n) { static bool DequantizeRewrite(const Node* n) {
DCHECK(n); DCHECK(n);
Node* input = nullptr; Node* input = nullptr;

View File

@ -25,6 +25,7 @@ limitations under the License.
#if defined(INTEL_MKL) #if defined(INTEL_MKL)
#include "mkldnn.hpp"
#include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/register_types.h"
@ -32,13 +33,6 @@ limitations under the License.
#include "tensorflow/core/kernels/mkl_matmul_ops_common.h" #include "tensorflow/core/kernels/mkl_matmul_ops_common.h"
#include "tensorflow/core/util/mkl_util.h" #include "tensorflow/core/util/mkl_util.h"
// This header file is part of MKL ML, need equivalent file in MKL DNN
#ifndef INTEL_MKL_DNN_ONLY
#include "mkl_cblas.h"
#endif
#include "mkldnn.h"
namespace tensorflow { namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice; typedef Eigen::ThreadPoolDevice CPUDevice;
@ -157,21 +151,11 @@ class MklMatMulOp : public OpKernel {
// 1.0 and 0.0 respectively. // 1.0 and 0.0 respectively.
const float alpha = 1.0f; const float alpha = 1.0f;
const float beta = 0.0f; const float beta = 0.0f;
#if defined(INTEL_MKL_DNN_ONLY) char char_transa = transa ? 'T' : 'N';
const char* const ftrans[] = {"N", "T", "C"}; char char_transb = transb ? 'T' : 'N';
int index_transa = transa ? 1 : 0; VLOG(2) << "MKL DNN SGEMM CALLED";
int index_transb = transb ? 1 : 0; dnnl_sgemm(char_transa, char_transb, m, n, k, alpha, a, lda, b, ldb, beta,
VLOG(2) << "MKL DNN SGEMM called"; c, ldc);
// MKL DNN only supports the Fortran api and requires column major while
// Tensorflow uses row major so we reverse the order A and B
mkldnn_sgemm(ftrans[index_transb], ftrans[index_transa], &n, &m, &k, &alpha,
b, &ldb, a, &lda, &beta, c, &ldc);
#else
// MKL ML binary uses CBLAS API
cblas_sgemm(CblasRowMajor, transa ? CblasTrans : CblasNoTrans,
transb ? CblasTrans : CblasNoTrans, m, n, k, alpha, a, lda, b,
ldb, beta, c, ldc);
#endif
} }
#ifdef ENABLE_INTEL_MKL_BFLOAT16 #ifdef ENABLE_INTEL_MKL_BFLOAT16
@ -205,53 +189,6 @@ class MklMatMulOp : public OpKernel {
FloatToBFloat16(c_float.flat<float>().data(), c, c_float.NumElements()); FloatToBFloat16(c_float.flat<float>().data(), c, c_float.NumElements());
} }
#endif // ENABLE_INTEL_MKL_BFLOAT16 #endif // ENABLE_INTEL_MKL_BFLOAT16
// MKL-DNN only supports SGEMM and bfloat16-GEMM.
#ifndef INTEL_MKL_DNN_ONLY
// Matrix-Matrix Multiplication with FP64 tensors. For detailed info about
// parameters, look at FP32 function description.
void MklBlasGemm(OpKernelContext* ctx, bool transa, bool transb, const int m,
const int n, const int k, const double* a, const int lda,
const double* b, const int ldb, double* c, const int ldc) {
const double alpha = 1.0;
const double beta = 0.0;
cblas_dgemm(CblasRowMajor, transa ? CblasTrans : CblasNoTrans,
transb ? CblasTrans : CblasNoTrans, m, n, k, alpha, a, lda, b,
ldb, beta, c, ldc);
}
// Matrix-Matrix Multiplication with Complex64 (std::complex<float>) tensors.
// For detailed info about parameters, look at FP32 function description.
void MklBlasGemm(OpKernelContext* ctx, bool transa, bool transb, const int m,
const int n, const int k, const complex64* a, const int lda,
const complex64* b, const int ldb, complex64* c,
int const ldc) {
const MKL_Complex8 alpha = {1.0f, 0.0f};
const MKL_Complex8 beta = {0.0f, 0.0f};
cblas_cgemm(CblasRowMajor, transa ? CblasTrans : CblasNoTrans,
transb ? CblasTrans : CblasNoTrans, m, n, k, &alpha,
reinterpret_cast<const MKL_Complex8*>(a), lda,
reinterpret_cast<const MKL_Complex8*>(b), ldb, &beta,
reinterpret_cast<MKL_Complex8*>(c), ldc);
}
// Matrix-Matrix Multiplication with Complex128 (std::complex<double>)
// tensors. For detailed info about parameters, look at FP32 function
// description.
void MklBlasGemm(OpKernelContext* ctx, bool transa, bool transb, const int m,
const int n, const int k, const complex128* a, const int lda,
const complex128* b, const int ldb, complex128* c,
const int ldc) {
const MKL_Complex16 alpha = {1.0, 0.0};
const MKL_Complex16 beta = {0.0, 0.0};
cblas_zgemm(CblasRowMajor, transa ? CblasTrans : CblasNoTrans,
transb ? CblasTrans : CblasNoTrans, m, n, k, &alpha,
reinterpret_cast<const MKL_Complex16*>(a), lda,
reinterpret_cast<const MKL_Complex16*>(b), ldb, &beta,
reinterpret_cast<MKL_Complex16*>(c), ldc);
}
#endif // !INTEL_MKL_DNN_ONLY
}; };
#define REGISTER_CPU(T) \ #define REGISTER_CPU(T) \
@ -269,13 +206,6 @@ TF_CALL_float(REGISTER_CPU);
#ifdef ENABLE_INTEL_MKL_BFLOAT16 #ifdef ENABLE_INTEL_MKL_BFLOAT16
TF_CALL_bfloat16(REGISTER_CPU); TF_CALL_bfloat16(REGISTER_CPU);
#endif // ENABLE_INTEL_MKL_BFLOAT16 #endif // ENABLE_INTEL_MKL_BFLOAT16
#ifndef INTEL_MKL_DNN_ONLY
TF_CALL_double(REGISTER_CPU);
TF_CALL_complex64(REGISTER_CPU);
TF_CALL_complex128(REGISTER_CPU);
#endif // !INTEL_MKL_DNN_ONLY
#endif // ENABLE_MKL #endif // ENABLE_MKL
} // namespace tensorflow } // namespace tensorflow
#endif // INTEL_MKL #endif // INTEL_MKL

View File

@ -936,7 +936,7 @@ REGISTER_OP("_MklMatMul")
.Output("product: T") .Output("product: T")
.Attr("transpose_a: bool = false") .Attr("transpose_a: bool = false")
.Attr("transpose_b: bool = false") .Attr("transpose_b: bool = false")
.Attr("T: {bfloat16, float, double, complex64, complex128}") .Attr("T: {bfloat16, float}")
.SetShapeFn(shape_inference::MatMulShape); .SetShapeFn(shape_inference::MatMulShape);
#endif // INTEL_MKL #endif // INTEL_MKL