diff --git a/tensorflow/core/common_runtime/mkl_layout_pass.cc b/tensorflow/core/common_runtime/mkl_layout_pass.cc
index fbec7059743..b3e7262f72f 100644
--- a/tensorflow/core/common_runtime/mkl_layout_pass.cc
+++ b/tensorflow/core/common_runtime/mkl_layout_pass.cc
@@ -394,10 +394,10 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
                       kRewriteForLayoutPropagation});
     rinfo_.push_back({csinfo_.batch_matmul,
                       mkl_op_registry::GetMklOpName(csinfo_.batch_matmul),
-                      CopyAttrsAll, AlwaysRewrite, kRewriteForOpNameChange});
+                      CopyAttrsAll, MatMulRewrite, kRewriteForOpNameChange});
     rinfo_.push_back({csinfo_.batch_matmul_v2,
                       mkl_op_registry::GetMklOpName(csinfo_.batch_matmul_v2),
-                      CopyAttrsAll, AlwaysRewrite, kRewriteForOpNameChange});
+                      CopyAttrsAll, MatMulRewrite, kRewriteForOpNameChange});
     rinfo_.push_back(
         {csinfo_.concat, mkl_op_registry::GetMklOpName(csinfo_.concat),
          CopyAttrsAll, AlwaysRewrite, kRewriteForLayoutPropagation});
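Note: the layout pass now consults MatMulRewrite instead of unconditionally rewriting BatchMatMul/BatchMatMulV2 to their _Mkl* variants, which keeps the rewrite in sync with the narrowed kernel registrations below. A minimal sketch of the kind of dtype-gated predicate this refers to (the actual MatMulRewrite body lives elsewhere in mkl_layout_pass.cc; this sketch only assumes it keys off the node's T attr):

    // Sketch: allow the _Mkl* rewrite only for dtypes the MKL kernels are
    // still registered for (float and bfloat16; see math_ops.cc below).
    static bool MatMulRewrite(const Node* n) {
      DataType T;
      TF_CHECK_OK(GetNodeAttr(n->def(), "T", &T));
      return T == DT_FLOAT || T == DT_BFLOAT16;
    }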
diff --git a/tensorflow/core/kernels/mkl_batch_matmul_op.cc b/tensorflow/core/kernels/mkl_batch_matmul_op.cc
index 2d0b18edb27..87e6002d9cb 100644
--- a/tensorflow/core/kernels/mkl_batch_matmul_op.cc
+++ b/tensorflow/core/kernels/mkl_batch_matmul_op.cc
@@ -25,7 +25,7 @@ limitations under the License.
 
 #define EIGEN_USE_THREADS
 
-#if defined(INTEL_MKL) && !defined(INTEL_MKL_DNN_ONLY)
+#if defined(INTEL_MKL)
 #include <vector>
 
 #include "mkl_cblas.h"
@@ -54,7 +54,8 @@ typedef Eigen::ThreadPoolDevice CPUDevice;
 template <typename Device, typename Scalar>
 class BatchMatMulMkl : public OpKernel {
  public:
-  explicit BatchMatMulMkl(OpKernelConstruction* context) : OpKernel(context) {
+  explicit BatchMatMulMkl(OpKernelConstruction* context)
+      : OpKernel(context), eigen_batch_mm_v2_(context) {
     OP_REQUIRES_OK(context, context->GetAttr("adj_x", &adj_x_));
     OP_REQUIRES_OK(context, context->GetAttr("adj_y", &adj_y_));
   }
@@ -104,6 +105,14 @@ class BatchMatMulMkl : public OpKernel {
                     "In[0] and In[1] must have compatible batch dimensions: ",
                     lhs.shape().DebugString(), " vs. ",
                     rhs.shape().DebugString()));
+#ifdef ENABLE_MKLDNN_THREADPOOL
+    if (bcast.IsBroadcastingRequired()) {
+      // Call the Eigen kernel for the broadcasting case and return. Eigen
+      // does not have BF16 support, so we have to fail gracefully there.
+      eigen_batch_mm_v2_.Compute(ctx);
+      return;
+    }
+#endif  // ENABLE_MKLDNN_THREADPOOL
 
     TensorShape out_shape = bcast.output_batch_shape();
     auto batch_size = bcast.output_batch_size();
@@ -149,22 +158,27 @@ class BatchMatMulMkl : public OpKernel {
     std::vector<MKL_INT> ldc_array(batch_size, N);
     std::vector<MKL_INT> group_size(1, batch_size);
 
-    if (std::is_same<Scalar, bfloat16>::value) {
+    bool threadpool_enabled = false;
+#ifdef ENABLE_MKLDNN_THREADPOOL
+    threadpool_enabled = true;
+#endif  // ENABLE_MKLDNN_THREADPOOL
+    if (std::is_same<Scalar, bfloat16>::value || threadpool_enabled) {
       // DNNL bfloat16 API requires a, b, and c as pointers to tensors
       // represented as flat-byte array.
       const Scalar* a = nullptr;
       const Scalar* b = nullptr;
-      OP_REQUIRES(ctx, !bcast.IsBroadcastingRequired(),
-                  errors::Unimplemented("Broadcasting is not supported for "
-                                        "BFloat16 _MklBatchMatMul yet."));
+      Scalar* c = nullptr;
       a = &lhs_reshaped(0, 0, 0);
       b = &rhs_reshaped(0, 0, 0);
-      Scalar* c = &out_reshaped(0, 0, 0);
+      OP_REQUIRES(ctx, !bcast.IsBroadcastingRequired(),
+                  errors::Unimplemented("Broadcasting is not supported for "
+                                        "_MklBatchMatMul yet."));
+      c = &out_reshaped(0, 0, 0);
       // TODO(nhasabni): Use appropriate cast instead of passing addresses of
       // a,b and c.
       MklCblasGemmBatch(CblasRowMajor, adj_x_, adj_y_, m_array, n_array,
                         k_array, &a, lda_array, &b, ldb_array, &c, ldc_array, 1,
-                        group_size);
+                        group_size, ctx);
     } else {
       std::vector<const Scalar*> a_array;
       std::vector<const Scalar*> b_array;
@@ -196,86 +210,48 @@ class BatchMatMulMkl : public OpKernel {
       // pointer is to 2D matrix.
       MklCblasGemmBatch(CblasRowMajor, adj_x_, adj_y_, m_array, n_array,
                         k_array, &a_array[0], lda_array, &b_array[0], ldb_array,
-                        &c_array[0], ldc_array, 1, group_size);
+                        &c_array[0], ldc_array, 1, group_size, ctx);
     }
   }
 
  private:
   bool adj_x_;
   bool adj_y_;
+  BatchMatMulV2Op<CPUDevice, Scalar> eigen_batch_mm_v2_;
 
-  template <typename T,
-            typename std::enable_if<(std::is_same<T, float>::value ||
-                                     std::is_same<T, double>::value),
-                                    int>::type = 0>
-  void MklCblasGemmBatch(const CBLAS_LAYOUT Layout, const bool TransA,
-                         const bool TransB, const std::vector<MKL_INT>& M_Array,
-                         const std::vector<MKL_INT>& N_Array,
-                         const std::vector<MKL_INT>& K_Array, const T** A_Array,
-                         const std::vector<MKL_INT>& lda_Array,
-                         const T** B_Array,
-                         const std::vector<MKL_INT>& ldb_Array, T** C_Array,
-                         const std::vector<MKL_INT>& ldc_Array,
-                         const MKL_INT group_count,
-                         const std::vector<MKL_INT>& group_size) {
+  void MklCblasGemmBatch(
+      const CBLAS_LAYOUT Layout, const bool TransA, const bool TransB,
+      const std::vector<MKL_INT>& M_Array, const std::vector<MKL_INT>& N_Array,
+      const std::vector<MKL_INT>& K_Array, const float** A_Array,
+      const std::vector<MKL_INT>& lda_Array, const float** B_Array,
+      const std::vector<MKL_INT>& ldb_Array, float** C_Array,
+      const std::vector<MKL_INT>& ldc_Array, const MKL_INT group_count,
+      const std::vector<MKL_INT>& group_size, OpKernelContext* ctx) {
+#ifndef ENABLE_MKLDNN_THREADPOOL
     std::vector<CBLAS_TRANSPOSE> TransA_Array(
         group_size[0], TransA ? CblasTrans : CblasNoTrans);
     std::vector<CBLAS_TRANSPOSE> TransB_Array(
         group_size[0], TransB ? CblasTrans : CblasNoTrans);
-    if (std::is_same<T, float>::value) {
-      std::vector<float> alpha_Array(group_size[0], 1.0);
-      std::vector<float> beta_Array(group_size[0], 0.0);
-      cblas_sgemm_batch(Layout, &TransA_Array[0], &TransB_Array[0], &M_Array[0],
-                        &N_Array[0], &K_Array[0], &alpha_Array[0],
-                        reinterpret_cast<const float**>(A_Array), &lda_Array[0],
-                        reinterpret_cast<const float**>(B_Array), &ldb_Array[0],
-                        &beta_Array[0], reinterpret_cast<float**>(C_Array),
-                        &ldc_Array[0], group_count, &group_size[0]);
-    } else {
-      std::vector<double> alpha_Array(group_size[0], 1.0);
-      std::vector<double> beta_Array(group_size[0], 0.0);
-      cblas_dgemm_batch(
-          Layout, &TransA_Array[0], &TransB_Array[0], &M_Array[0], &N_Array[0],
-          &K_Array[0], &alpha_Array[0],
-          reinterpret_cast<const double**>(A_Array), &lda_Array[0],
-          reinterpret_cast<const double**>(B_Array), &ldb_Array[0],
-          &beta_Array[0], reinterpret_cast<double**>(C_Array), &ldc_Array[0],
-          group_count, &group_size[0]);
-    }
+    std::vector<float> alpha_Array(group_size[0], 1.0);
+    std::vector<float> beta_Array(group_size[0], 0.0);
+    cblas_sgemm_batch(Layout, &TransA_Array[0], &TransB_Array[0], &M_Array[0],
+                      &N_Array[0], &K_Array[0], &alpha_Array[0],
+                      reinterpret_cast<const float**>(A_Array), &lda_Array[0],
+                      reinterpret_cast<const float**>(B_Array), &ldb_Array[0],
+                      &beta_Array[0], reinterpret_cast<float**>(C_Array),
+                      &ldc_Array[0], group_count, &group_size[0]);
+#else
+    DCHECK(Layout == CblasRowMajor);
+    std::vector<bool> TransA_Array(group_size[0], TransA);
+    std::vector<bool> TransB_Array(group_size[0], TransB);
+    std::vector<float> alpha_Array(group_size[0], 1.0);
+    std::vector<float> beta_Array(group_size[0], 0.0);
+    dnnl_gemm_batch<float>(TransA_Array, TransB_Array, M_Array, N_Array,
+                           K_Array, alpha_Array, *A_Array, *B_Array, beta_Array,
+                           *C_Array, group_count, group_size, ctx);
+#endif  // !ENABLE_MKLDNN_THREADPOOL
   }
-
-  template <typename T,
-            typename std::enable_if<(std::is_same<T, complex64>::value ||
-                                     std::is_same<T, complex128>::value),
-                                    int>::type = 0>
-  void MklCblasGemmBatch(const CBLAS_LAYOUT Layout, const bool TransA,
-                         const bool TransB, const std::vector<MKL_INT>& M_Array,
-                         const std::vector<MKL_INT>& N_Array,
-                         const std::vector<MKL_INT>& K_Array, const T** A_Array,
-                         const std::vector<MKL_INT>& lda_Array,
-                         const T** B_Array,
-                         const std::vector<MKL_INT>& ldb_Array, T** C_Array,
-                         const std::vector<MKL_INT>& ldc_Array,
-                         const MKL_INT group_count,
-                         const std::vector<MKL_INT>& group_size) {
-    std::vector<CBLAS_TRANSPOSE> TransA_array(
-        group_size[0], TransA ? CblasConjTrans : CblasNoTrans);
-    std::vector<CBLAS_TRANSPOSE> TransB_array(
-        group_size[0], TransB ? CblasConjTrans : CblasNoTrans);
-    std::vector<T> alpha_Array(group_size[0], {1.0f, 0.0f});
-    std::vector<T> beta_Array(group_size[0], {0.0f, 0.0f});
-    auto gemm_fn = (std::is_same<T, complex64>::value) ? cblas_cgemm_batch
-                                                       : cblas_zgemm_batch;
-    gemm_fn(Layout, &TransA_array[0], &TransB_array[0], &M_Array[0],
-            &N_Array[0], &K_Array[0], static_cast<const void*>(&alpha_Array[0]),
-            reinterpret_cast<const void**>(A_Array), &lda_Array[0],
-            reinterpret_cast<const void**>(B_Array), &ldb_Array[0],
-            static_cast<const void*>(&beta_Array[0]),
-            reinterpret_cast<void**>(C_Array), &ldc_Array[0], group_count,
-            &group_size[0]);
-  }
-
-  // BatchMatMul BFloat16 support only exists in DNNL 1.2 onwards.
+// BatchMatMul BFloat16 support only exists in DNNL 1.2 onwards.
 #if defined(ENABLE_MKLDNN_V1) && defined(ENABLE_INTEL_MKL_BFLOAT16)
   void MklCblasGemmBatch(
       const CBLAS_LAYOUT Layout, const bool TransA, const bool TransB,
@@ -284,7 +260,7 @@ class BatchMatMulMkl : public OpKernel {
       const std::vector<MKL_INT>& lda_Array, const bfloat16** B_Array,
       const std::vector<MKL_INT>& ldb_Array, bfloat16** C_Array,
       const std::vector<MKL_INT>& ldc_Array, const MKL_INT group_count,
-      const std::vector<MKL_INT>& group_size) {
+      const std::vector<MKL_INT>& group_size, OpKernelContext* ctx) {
     DCHECK(Layout == CblasRowMajor);
     std::vector<bool> TransA_Array(group_size[0], TransA);
     std::vector<bool> TransB_Array(group_size[0], TransB);
@@ -292,9 +268,9 @@ class BatchMatMulMkl : public OpKernel {
     std::vector<float> beta_Array(group_size[0], 0.0);
     // TODO(nhasabni): Remove *A when we pass a, b, and c correctly.
     // MKLDNN API does not require lda, ldb, and ldc.
-    dnnl_gemm_batch<bfloat16>(TransA_Array, TransB_Array, M_Array, N_Array,
-                              K_Array, alpha_Array, *A_Array, *B_Array,
-                              beta_Array, *C_Array, group_count, group_size);
+    dnnl_gemm_batch<bfloat16>(
+        TransA_Array, TransB_Array, M_Array, N_Array, K_Array, alpha_Array,
+        *A_Array, *B_Array, beta_Array, *C_Array, group_count, group_size, ctx);
   }
 #endif  // ENABLE_MKLDNN_V1 && ENABLE_INTEL_MKL_BFLOAT16
 };
@@ -315,13 +291,7 @@ class BatchMatMulMkl : public OpKernel {
 
 #ifdef ENABLE_MKL
 TF_CALL_float(REGISTER_BATCH_MATMUL_MKL);
-TF_CALL_double(REGISTER_BATCH_MATMUL_MKL);
-TF_CALL_COMPLEX_TYPES(REGISTER_BATCH_MATMUL_MKL);
-
 TF_CALL_float(REGISTER_BATCH_MATMUL_MKL_V2);
-TF_CALL_double(REGISTER_BATCH_MATMUL_MKL_V2);
-TF_CALL_COMPLEX_TYPES(REGISTER_BATCH_MATMUL_MKL_V2);
-
 #if defined(ENABLE_MKLDNN_V1) && defined(ENABLE_INTEL_MKL_BFLOAT16)
 TF_CALL_bfloat16(REGISTER_BATCH_MATMUL_MKL);
 TF_CALL_bfloat16(REGISTER_BATCH_MATMUL_MKL_V2);
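Note: after this change, a non-threadpool build sends float through MKL's grouped cblas_sgemm_batch, while ENABLE_MKLDNN_THREADPOOL builds route both float and bfloat16 through dnnl_gemm_batch, threading the kernel context along. A self-contained sketch of the grouped-GEMM calling convention the float path relies on, with a single group whose size equals the batch (all names here are illustrative, not from the patch):

    // Sketch: MKL group-batch SGEMM with group_count == 1, so every matrix in
    // the batch shares one set of m/n/k/lda/ldb/ldc/alpha/beta parameters.
    // Assumes an MKL build providing mkl_cblas.h.
    #include <vector>
    #include "mkl_cblas.h"

    void sgemm_batch_demo() {
      const MKL_INT batch = 2, m = 2, n = 2, k = 2;
      std::vector<float> a0 = {1, 2, 3, 4}, a1 = {5, 6, 7, 8};
      std::vector<float> b0 = {1, 0, 0, 1}, b1 = {1, 0, 0, 1};  // identities
      std::vector<float> c0(4), c1(4);
      const float* a_array[] = {a0.data(), a1.data()};
      const float* b_array[] = {b0.data(), b1.data()};
      float* c_array[] = {c0.data(), c1.data()};
      // One entry per group; the pointer arrays hold one entry per matrix.
      CBLAS_TRANSPOSE trans = CblasNoTrans;
      float alpha = 1.0f, beta = 0.0f;
      MKL_INT lda = k, ldb = n, ldc = n, group_size = batch;
      cblas_sgemm_batch(CblasRowMajor, &trans, &trans, &m, &n, &k, &alpha,
                        a_array, &lda, b_array, &ldb, &beta, c_array, &ldc,
                        /*group_count=*/1, &group_size);
    }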
diff --git a/tensorflow/core/kernels/mkl_matmul_ops_common.h b/tensorflow/core/kernels/mkl_matmul_ops_common.h
index d3a05a4a6d2..490afd55932 100644
--- a/tensorflow/core/kernels/mkl_matmul_ops_common.h
+++ b/tensorflow/core/kernels/mkl_matmul_ops_common.h
@@ -707,8 +707,8 @@ void dnnl_gemm_batch(const std::vector<bool>& transa,
                      const std::vector<int>& n, const std::vector<int>& k,
                      const std::vector<float>& alpha, const T* a, const T* b,
                      const std::vector<float>& beta, T* c,
-                     const int group_count,
-                     const std::vector<int>& group_size) {
+                     const int group_count, const std::vector<int>& group_size,
+                     OpKernelContext* ctx = nullptr) {
   // Current BatchMatMul support in Tensorflow is narrower than the one offered
   // by MKL and MKL-DNN. Current BatchMatMul support in Tensorflow uses only 1
   // group of size equal to batch_size, and all MatMul parameters (m, n, k,
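Note: dnnl_gemm_batch gains an optional OpKernelContext* so threadpool-aware builds can execute DNNL primitives on the op's own thread pool; the nullptr default keeps pre-existing call sites source-compatible. A sketch of the resulting call pattern (RunBatchedGemm is a hypothetical wrapper for illustration; it assumes the TF headers that declare OpKernelContext and the dnnl_gemm_batch<T> template above):

    // Sketch: one group spanning the whole batch, mirroring how the MKL
    // kernel above invokes dnnl_gemm_batch after this change.
    #include <vector>

    void RunBatchedGemm(const float* a, const float* b, float* c, int batch,
                        int m, int n, int k, bool adj_x, bool adj_y,
                        tensorflow::OpKernelContext* ctx) {
      std::vector<bool> transa(batch, adj_x), transb(batch, adj_y);
      std::vector<int> m_arr(batch, m), n_arr(batch, n), k_arr(batch, k);
      std::vector<float> alpha(batch, 1.0f), beta(batch, 0.0f);
      std::vector<int> group_size(1, batch);  // one group of size batch
      tensorflow::dnnl_gemm_batch<float>(transa, transb, m_arr, n_arr, k_arr,
                                         alpha, a, b, beta, c,
                                         /*group_count=*/1, group_size, ctx);
    }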
diff --git a/tensorflow/core/ops/math_ops.cc b/tensorflow/core/ops/math_ops.cc
index dfc2463915c..cc20bc7b4d6 100644
--- a/tensorflow/core/ops/math_ops.cc
+++ b/tensorflow/core/ops/math_ops.cc
@@ -142,9 +142,7 @@ REGISTER_OP("_MklBatchMatMul")
     .Input("x: T")
     .Input("y: T")
     .Output("output: T")
-    .Attr(
-        "T: {bfloat16, half, float, double, int32, int64, complex64, "
-        "complex128}")
+    .Attr("T: {bfloat16, float}")
     .Attr("adj_x: bool = false")
     .Attr("adj_y: bool = false")
     .SetShapeFn(shape_inference::BatchMatMulShape);
@@ -153,9 +151,7 @@ REGISTER_OP("_MklBatchMatMulV2")
     .Input("x: T")
     .Input("y: T")
     .Output("output: T")
-    .Attr(
-        "T: {bfloat16, half, float, double, int32, int64, complex64, "
-        "complex128}")
+    .Attr("T: {bfloat16, float}")
     .Attr("adj_x: bool = false")
     .Attr("adj_y: bool = false")
     .SetShapeFn(shape_inference::BatchMatMulV2Shape);
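Note: the narrowed T attr must stay in sync with the kernel registrations left after the TF_CALL_double/TF_CALL_COMPLEX_TYPES removals above; otherwise the op definition would advertise dtypes for which no _Mkl kernel exists. The registration macro in mkl_batch_matmul_op.cc has roughly this shape (reconstructed for illustration; treat details such as the label as approximate):

    // Approximate shape of the kernel registration the op's T list must match.
    #define REGISTER_BATCH_MATMUL_MKL(TYPE)                                 \
      REGISTER_KERNEL_BUILDER(                                              \
          Name("_MklBatchMatMul")                                           \
              .Device(DEVICE_CPU)                                           \
              .TypeConstraint<TYPE>("T")                                    \
              .Label(mkl_op_registry::kMklNameChangeOpLabel),               \
          BatchMatMulMkl<CPUDevice, TYPE>)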