diff --git a/tensorflow/core/kernels/mkl_matmul_op.cc b/tensorflow/core/kernels/mkl_matmul_op.cc
index 86193901c96..fb197618fb0 100644
--- a/tensorflow/core/kernels/mkl_matmul_op.cc
+++ b/tensorflow/core/kernels/mkl_matmul_op.cc
@@ -62,11 +62,11 @@ class MklMatMulOp : public OpKernel {
     dim_pair[0].first = transpose_a_ ? 0 : 1;
     dim_pair[0].second = transpose_b_ ? 1 : 0;
 
-    OP_REQUIRES(
-        ctx, a.dim_size(dim_pair[0].first) == b.dim_size(dim_pair[0].second),
-        errors::InvalidArgument(
-            "Matrix size-incompatible: In[0]: ", a.shape().DebugString(),
-            ", In[1]: ", b.shape().DebugString()));
+    OP_REQUIRES(ctx,
+                a.dim_size(dim_pair[0].first) == b.dim_size(dim_pair[0].second),
+                errors::InvalidArgument("Matrix size-incompatible: In[0]: ",
+                                        a.shape().DebugString(), ", In[1]: ",
+                                        b.shape().DebugString()));
     int a_dim_remaining = 1 - dim_pair[0].first;
     int b_dim_remaining = 1 - dim_pair[0].second;
     TensorShape out_shape(
@@ -158,9 +158,17 @@ class MklMatMulOp : public OpKernel {
 #ifdef ENABLE_MKLDNN_V1
     char char_transa = transa ? 'T' : 'N';
     char char_transb = transb ? 'T' : 'N';
-    VLOG(2) << "MKL DNN SGEMM CALLED";
+    VLOG(2) << "MKL DNN SGEMM called";
+#ifdef ENABLE_MKLDNN_THREADPOOL
+    auto eigen_tp =
+        MklDnnThreadPoolWrapper::GetInstance().CreateThreadPoolPtr(ctx);
+
+    dnnl_sgemm_tp(char_transa, char_transb, m, n, k, alpha, a, lda, b, ldb,
+                  beta, c, ldc, eigen_tp);
+#else
     dnnl_sgemm(char_transa, char_transb, m, n, k, alpha, a, lda, b, ldb, beta,
                c, ldc);
+#endif  // ENABLE_MKLDNN_THREADPOOL
 #else
     // TODO(intel-tf): Remove this after TF2.3 fork.
     cblas_sgemm(CblasRowMajor, transa ? CblasTrans : CblasNoTrans,
@@ -182,7 +190,7 @@ class MklMatMulOp : public OpKernel {
 #ifdef ENABLE_MKLDNN_V1
     const char ftrans[] = {'N', 'T', 'C'};
     dnnl_gemm<bfloat16>(ftrans[index_transa], ftrans[index_transb], m, n, k,
-                        alpha, a, lda, b, ldb, beta, c, ldc);
+                        alpha, a, lda, b, ldb, beta, c, ldc, ctx);
 #else
     Tensor c_float;
     OP_REQUIRES_OK(ctx, ctx->allocate_temp(DT_FLOAT, {m, n}, &c_float));
diff --git a/tensorflow/core/kernels/mkl_matmul_op_fused.cc b/tensorflow/core/kernels/mkl_matmul_op_fused.cc
index 99a2cfc214b..f3608ef72a8 100644
--- a/tensorflow/core/kernels/mkl_matmul_op_fused.cc
+++ b/tensorflow/core/kernels/mkl_matmul_op_fused.cc
@@ -86,11 +86,10 @@ class MklFusedMatMulOp : public MklDnnMatMulOpBase<T, T> {
     const int k = src_tf_shape.dim_size(dim_pair[0]);
     const int channel = weight_tf_shape.dim_size(1 - dim_pair[1]);
 
-    OP_REQUIRES(
-        ctx, k == weight_tf_shape.dim_size(dim_pair[1]),
-        errors::InvalidArgument(
-            "Matrix size-incompatible: In[0]: ", src_tf_shape.DebugString(),
-            ", In[1]: ", weight_tf_shape.DebugString()));
+    OP_REQUIRES(ctx, k == weight_tf_shape.dim_size(dim_pair[1]),
+                errors::InvalidArgument("Matrix size-incompatible: In[0]: ",
+                                        src_tf_shape.DebugString(), ", In[1]: ",
+                                        weight_tf_shape.DebugString()));
     OP_REQUIRES(ctx, bias_tensor.shape().dim_size(0) == channel,
                 errors::InvalidArgument(
                     "Must provide as many biases as the channel size: ",
@@ -159,8 +158,10 @@ class MklFusedMatMulOp : public MklDnnMatMulOpBase<T, T> {
 
       if (IS_SRC_REORDER_NEEDED(src_md, matmul_pd, matmul_prim)) {
         src_mkl.SetUsrMem(src_md, src_data);
-        src_mkl.CheckReorderToOpMem(MEMORY_PD_WITHOUT_DATA(
-            matmul_pd.get()->PRIMITIVE_DESC_SRC, this->cpu_engine_));
+        src_mkl.CheckReorderToOpMem(
+            MEMORY_PD_WITHOUT_DATA(matmul_pd.get()->PRIMITIVE_DESC_SRC,
+                                   this->cpu_engine_),
+            ctx);
         src_data = reinterpret_cast<T*>(src_mkl.GetOpMem().get_data_handle());
       }
 
@@ -191,19 +192,23 @@ class MklFusedMatMulOp : public MklDnnMatMulOpBase<T, T> {
           weight_data = cached_weight_data;
         } else {
           weight_mkl.SetUsrMem(weight_md, weight_data);
-          weight_mkl.CheckReorderToOpMem(MEMORY_PD_WITHOUT_DATA(
-              matmul_pd.get()->PRIMITIVE_DESC_WEIGHTS, this->cpu_engine_));
+          weight_mkl.CheckReorderToOpMem(
+              MEMORY_PD_WITHOUT_DATA(matmul_pd.get()->PRIMITIVE_DESC_WEIGHTS,
+                                     this->cpu_engine_),
+              ctx);
           weight_data =
               reinterpret_cast<T*>(weight_mkl.GetOpMem().get_data_handle());
         }
       }
-
+      std::shared_ptr<stream> cpu_stream;
+      cpu_stream.reset(CreateStream(ctx, matmul_prim->GetEngine()));
       // Execute fused matmul op.
-      matmul_prim->Execute(src_data, weight_data, bias_data, dst_data);
+      matmul_prim->Execute(src_data, weight_data, bias_data, dst_data,
+                           cpu_stream);
     } catch (mkldnn::error& e) {
-      string error_msg = "Status: " + std::to_string(e.status) +
-                         ", message: " + string(e.message) + ", in file " +
-                         string(__FILE__) + ":" + std::to_string(__LINE__);
+      string error_msg = "Status: " + std::to_string(e.status) + ", message: " +
+                         string(e.message) + ", in file " + string(__FILE__) +
+                         ":" + std::to_string(__LINE__);
       OP_REQUIRES_OK(
           ctx, errors::Aborted("Operation received an exception:", error_msg));
     }
diff --git a/tensorflow/core/kernels/mkl_matmul_ops_common.h b/tensorflow/core/kernels/mkl_matmul_ops_common.h
index d3a05a4a6d2..d7af614ad04 100644
--- a/tensorflow/core/kernels/mkl_matmul_ops_common.h
+++ b/tensorflow/core/kernels/mkl_matmul_ops_common.h
@@ -75,8 +75,7 @@ class MklDnnMatMulFwdPrimitive : public MklPrimitive {
  public:
   explicit MklDnnMatMulFwdPrimitive(
       const MklDnnMatMulFwdParams& matmulFwdParams)
-      : cpu_engine_(ENGINE_CPU, 0) {
-    context_.fwd_stream.reset(new CPU_STREAM(cpu_engine_));
+      : MklPrimitive(engine(ENGINE_CPU, 0)) {
     // Create matmul primitive
     if (context_.matmul_fwd == nullptr) {
       Setup(matmulFwdParams);
@@ -91,7 +90,18 @@ class MklDnnMatMulFwdPrimitive : public MklPrimitive {
   //  - bias_data: input data buffer of bias
   //  - dst_data: output data buffer of dst
   void Execute(const Tinput* src_data, const Tweight* weight_data,
-               const Tbias* bias_data, Toutput* dst_data) {
+               const Tbias* bias_data, Toutput* dst_data,
+               std::shared_ptr<stream> fwd_stream) {
+#ifdef ENABLE_MKLDNN_THREADPOOL
+    context_.src_mem->set_data_handle(
+        static_cast<void*>(const_cast<Tinput*>(src_data)), *fwd_stream);
+    context_.weight_mem->set_data_handle(
+        static_cast<void*>(const_cast<Tweight*>(weight_data)), *fwd_stream);
+    context_.bias_mem->set_data_handle(
+        static_cast<void*>(const_cast<Tbias*>(bias_data)));
+    context_.dst_mem->set_data_handle(static_cast<void*>(dst_data),
+                                      *fwd_stream);
+#else
     context_.src_mem->set_data_handle(
         static_cast<void*>(const_cast<Tinput*>(src_data)));
     context_.weight_mem->set_data_handle(
@@ -99,12 +109,12 @@ class MklDnnMatMulFwdPrimitive : public MklPrimitive {
     context_.bias_mem->set_data_handle(
         static_cast<void*>(const_cast<Tbias*>(bias_data)));
     context_.dst_mem->set_data_handle(static_cast<void*>(dst_data));
+#endif  // ENABLE_MKLDNN_THREADPOOL
 
 #ifdef ENABLE_MKLDNN_V1
-    execute_primitives(context_.fwd_primitives, context_.fwd_stream,
-                       context_.net_args);
+    execute_primitives(context_.fwd_primitives, fwd_stream, context_.net_args);
 #else
-    context_.fwd_stream->submit(context_.fwd_primitives);
+    fwd_stream->submit(context_.fwd_primitives);
 #endif  // ENABLE_MKLDNN_V1
 
     // After execution, set data handle back
@@ -153,7 +163,6 @@ class MklDnnMatMulFwdPrimitive : public MklPrimitive {
 
     // Inner-product primitive.
     std::shared_ptr<mkldnn::primitive> matmul_fwd;
-    std::shared_ptr<mkldnn::stream> fwd_stream;
     std::vector<mkldnn::primitive> fwd_primitives;
 
 #ifdef ENABLE_MKLDNN_V1
@@ -176,8 +185,7 @@ class MklDnnMatMulFwdPrimitive : public MklPrimitive {
           weight_md(nullptr),
           bias_md(nullptr),
           dst_md(nullptr),
-          matmul_fwd(nullptr),
-          fwd_stream(nullptr) {
+          matmul_fwd(nullptr) {
     }
   };
 
@@ -292,7 +300,6 @@ class MklDnnMatMulFwdPrimitive : public MklPrimitive {
   }
 
   struct MklDnnMatMulFwdContext context_;
-  engine cpu_engine_;
 };
 
 template <typename T, typename Tinput, typename Tweight, typename Tbias,
@@ -439,8 +446,10 @@ class MklDnnMatMulOpBase : public OpKernel {
 
     // reorder and cache the weight
     weight.SetUsrMem(weight_md, &weight_tensor);
-    weight.CheckReorderToOpMem(MEMORY_PD_WITHOUT_DATA(
-        matmul_fwd_pd.get()->PRIMITIVE_DESC_WEIGHTS, cpu_engine_));
+    weight.CheckReorderToOpMem(
+        MEMORY_PD_WITHOUT_DATA(matmul_fwd_pd.get()->PRIMITIVE_DESC_WEIGHTS,
+                               cpu_engine_),
+        context);
     weight_data = static_cast<Tweight*>(weight.GetOpMem().get_data_handle());
 
     Tensor* weight_tensor_ptr = nullptr;
@@ -544,21 +553,28 @@ template <typename T>
 class MklMatMulPrimitive : public MklPrimitive {
  public:
   explicit MklMatMulPrimitive(const MklMatMulParams& params)
-      : cpu_engine_(ENGINE_CPU, 0) {
-    context_.stream.reset(new CPU_STREAM(cpu_engine_));
+      : MklPrimitive(engine(ENGINE_CPU, 0)) {
     // Create matmul primitive
     Setup(params);
   }
 
   ~MklMatMulPrimitive() {}
 
-  void Execute(const T* a_data, const T* b_data, T* c_data) {
+  void Execute(const T* a_data, const T* b_data, T* c_data,
+               std::shared_ptr<stream> stream) {
+#ifdef ENABLE_MKLDNN_THREADPOOL
+    context_.a_mem->set_data_handle(static_cast<void*>(const_cast<T*>(a_data)),
+                                    *stream);
+    context_.b_mem->set_data_handle(static_cast<void*>(const_cast<T*>(b_data)),
+                                    *stream);
+    context_.c_mem->set_data_handle(static_cast<void*>(const_cast<T*>(c_data)),
+                                    *stream);
+#else
     context_.a_mem->set_data_handle(static_cast<void*>(const_cast<T*>(a_data)));
     context_.b_mem->set_data_handle(static_cast<void*>(const_cast<T*>(b_data)));
     context_.c_mem->set_data_handle(static_cast<void*>(const_cast<T*>(c_data)));
-
-    execute_primitives(context_.matmul_primitives, context_.stream,
-                       context_.net_args);
+#endif  // ENABLE_MKLDNN_THREADPOOL
+    execute_primitives(context_.matmul_primitives, stream, context_.net_args);
 
     // After execution, set data handle back
     context_.a_mem->set_data_handle(DummyData);
@@ -584,7 +600,6 @@ class MklMatMulPrimitive : public MklPrimitive {
 
     std::shared_ptr<mkldnn::memory::desc> c_md;
     // MatMul primitive.
-    std::shared_ptr<mkldnn::stream> stream;
     std::vector<mkldnn::primitive> matmul_primitives;
     std::vector<std::unordered_map<int, memory>> net_args;
 
@@ -596,8 +611,7 @@ class MklMatMulPrimitive : public MklPrimitive {
           prim_desc(nullptr),
           a_md(nullptr),
           b_md(nullptr),
-          c_md(nullptr),
-          stream(nullptr) {}
+          c_md(nullptr) {}
   };
 
   void Setup(const MklMatMulParams& params) {
@@ -639,7 +653,6 @@ class MklMatMulPrimitive : public MklPrimitive {
   }
 
   struct MklMatMulContext context_;
-  engine cpu_engine_;
 };
 
 template <typename T>
@@ -707,8 +720,8 @@ void dnnl_gemm_batch(const std::vector<bool>& transa,
                      const std::vector<int>& n, const std::vector<int>& k,
                      const std::vector<float>& alpha, const T* a, const T* b,
                      const std::vector<float>& beta, T* c,
-                     const int group_count,
-                     const std::vector<int>& group_size) {
+                     const int group_count, const std::vector<int>& group_size,
+                     OpKernelContext* ctx = nullptr) {
   // Current BatchMatMul support in Tensorflow is narrower than the one offered
   // by MKL and MKL-DNN. Current BatchMatMul support in Tensorflow uses only 1
   // group of size equal to batch_size, and all MatMul parameters (m, n, k,
@@ -757,13 +770,15 @@ void dnnl_gemm_batch(const std::vector<bool>& transa,
       MklMatMulPrimitiveFactory<T>::Get(params, 0);
 
   // Execute matmul primitive.
-  matmul_prim->Execute(a, b, c);
+  std::shared_ptr<stream> cpu_stream;
+  cpu_stream.reset(CreateStream(ctx, matmul_prim->GetEngine()));
+  matmul_prim->Execute(a, b, c, cpu_stream);
 }
 
 template <typename T>
 void dnnl_gemm(char transa, char transb, int64_t m, int64_t n, int64_t k,
                float alpha, const T* a, int64_t lda, const T* b, int64_t ldb,
-               float beta, T* c, int64_t ldc) {
+               float beta, T* c, int64_t ldc, OpKernelContext* ctx = nullptr) {
   using dims = mkldnn::memory::dims;
 
   // Prepare strides based on the transa and transb flags: transposed
@@ -786,7 +801,9 @@ void dnnl_gemm(char transa, char transb, int64_t m, int64_t n, int64_t k,
       MklMatMulPrimitiveFactory<T>::Get(params, 0);
 
   // Execute matmul primitive.
-  matmul_prim->Execute(a, b, c);
+  std::shared_ptr<stream> cpu_stream;
+  cpu_stream.reset(CreateStream(ctx, matmul_prim->GetEngine()));
+  matmul_prim->Execute(a, b, c, cpu_stream);
 }
 
 }  // anonymous namespace
diff --git a/tensorflow/core/kernels/mkl_qmatmul_op.cc b/tensorflow/core/kernels/mkl_qmatmul_op.cc
index cc7127e0559..e73f30db4da 100644
--- a/tensorflow/core/kernels/mkl_qmatmul_op.cc
+++ b/tensorflow/core/kernels/mkl_qmatmul_op.cc
@@ -245,8 +245,10 @@ class MklDnnQuantizedMatMulOp : public MklDnnMatMulOpBase<Tweight, Toutput> {
       Tinput* src_data = nullptr;
       if (IS_SRC_REORDER_NEEDED(src_md, matmul_fwd_pd, matmul_fwd)) {
         src.SetUsrMem(src_md, &src_tensor);
-        src.CheckReorderToOpMem(MEMORY_PD_WITHOUT_DATA(
-            matmul_fwd_pd.get()->PRIMITIVE_DESC_SRC, this->cpu_engine_));
+        src.CheckReorderToOpMem(
+            MEMORY_PD_WITHOUT_DATA(matmul_fwd_pd.get()->PRIMITIVE_DESC_SRC,
+                                   this->cpu_engine_),
+            context);
         src_data = static_cast<Tinput*>(src.GetOpMem().get_data_handle());
       } else {
         src_data = static_cast<Tinput*>(
@@ -279,8 +281,11 @@ class MklDnnQuantizedMatMulOp : public MklDnnMatMulOpBase<Tweight, Toutput> {
 
       if (!is_weight_cached) {
         weight.SetUsrMem(weight_md, &weight_tensor);
-        weight.CheckReorderToOpMem(MEMORY_PD_WITHOUT_DATA(
-            matmul_fwd_pd.get()->PRIMITIVE_DESC_WEIGHTS, this->cpu_engine_));
+        weight.CheckReorderToOpMem(
+            MEMORY_PD_WITHOUT_DATA(
+                matmul_fwd_pd.get()->PRIMITIVE_DESC_WEIGHTS,
+                this->cpu_engine_),
+            context);
         weight_data =
             static_cast<Tweight*>(weight.GetOpMem().get_data_handle());
       }
@@ -290,10 +295,13 @@ class MklDnnQuantizedMatMulOp : public MklDnnMatMulOpBase<Tweight, Toutput> {
             const_cast<Tweight*>(weight_tensor.flat<Tweight>().data()));
       }
 
+      std::shared_ptr<stream> cpu_stream;
+      cpu_stream.reset(CreateStream(context, matmul_fwd->GetEngine()));
       // Execute inner-product
-      Tbias* bias_data = this->GetBiasHandle(context, matmul_fwd_pd,
-                                             bias_tensor, weight_tensor);
-      matmul_fwd->Execute(src_data, weight_data, bias_data, dst_data);
+      Tbias* bias_data = this->GetBiasHandle(
+          context, matmul_fwd_pd, bias_tensor, weight_tensor, cpu_stream);
+      matmul_fwd->Execute(src_data, weight_data, bias_data, dst_data,
+                          cpu_stream);
     } catch (mkldnn::error& e) {
       string error_msg = tensorflow::strings::StrCat(
           "Status: ", e.status, ", message: ", string(e.message), ", in file ",
@@ -393,7 +401,8 @@ class MklDnnQuantizedMatMulOp : public MklDnnMatMulOpBase<Tweight, Toutput> {
       OpKernelContext* context,
       std::shared_ptr<mkldnn::inner_product_forward::primitive_desc>&
          mkldnn_matmul_fwd_pd,
-      const Tensor& bias_tensor, const Tensor& weight_tensor) {
+      const Tensor& bias_tensor, const Tensor& weight_tensor,
+      std::shared_ptr<stream> reorder_stream) {
     // If the bias is qint32, it means the bias is already converted offline.
     // and it can be added to matmul output directly.
     if (std::is_same<Tbias, qint32>::value) {
@@ -449,7 +458,6 @@ class MklDnnQuantizedMatMulOp : public MklDnnMatMulOpBase<Tweight, Toutput> {
       std::vector<float> scales;
       scales.push_back(out_scale);
       mkldnn::primitive_attr bias_attr;
-      stream reorder_stream = CPU_STREAM(this->cpu_engine_);
       bias_attr.set_output_scales(0, scales);
 
       void* bias_buf = static_cast<void*>(
@@ -468,14 +476,14 @@ class MklDnnQuantizedMatMulOp : public MklDnnMatMulOpBase<Tweight, Toutput> {
           {MKLDNN_ARG_FROM, *input_bias_},
           { MKLDNN_ARG_TO, *scaled_bias_ }};
-      net.at(0).execute(reorder_stream, reorder_net_args);
+      net.at(0).execute(*reorder_stream, reorder_net_args);
 #else
       auto reorder_desc = mkldnn::reorder::primitive_desc(
          input_bias_->get_primitive_desc(), scaled_bias_->get_primitive_desc(),
          bias_attr);
       net.push_back(
          mkldnn::reorder(reorder_desc, *input_bias_, *scaled_bias_));
-      reorder_stream.submit(net).wait();
+      reorder_stream->submit(net).wait();
 #endif  // ENABLE_MKLDNN_V1
 
       return reinterpret_cast<Tbias*>(scaled_bias_->get_data_handle());
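The recurring change in every kernel above is the same: the cached primitive no longer constructs and stores its own CPU stream; instead the op builds a stream from the OpKernelContext (via CreateStream) at execution time and threads it through Execute(), so a threadpool-backed stream can be used when ENABLE_MKLDNN_THREADPOOL is defined. The sketch below illustrates that calling pattern in isolation; it is a minimal, hypothetical stand-in (Stream, KernelContext, MakeStreamForContext, and MatMulPrimitive are invented names for illustration, not TensorFlow or oneDNN APIs).

```cpp
// Minimal sketch of the "stream passed per call" pattern used in this patch.
// All names here are hypothetical stand-ins, not TensorFlow or oneDNN APIs.
#include <iostream>
#include <memory>
#include <vector>

// Stand-in for an execution stream (dnnl::stream in the real code).
struct Stream {
  explicit Stream(int id) : id(id) {}
  int id;
};

// Stand-in for OpKernelContext; in the patch it supplies the Eigen threadpool.
struct KernelContext {
  int device_id = 0;
};

// Placeholder for CreateStream(ctx, engine), which in the real patch builds a
// threadpool-aware stream when ENABLE_MKLDNN_THREADPOOL is defined.
std::shared_ptr<Stream> MakeStreamForContext(const KernelContext& ctx) {
  return std::make_shared<Stream>(ctx.device_id);
}

// A cached, reusable primitive. Note that it no longer owns a stream; the
// stream is supplied on each call, which is the change the patch makes.
class MatMulPrimitive {
 public:
  void Execute(const std::vector<float>& a, const std::vector<float>& b,
               std::vector<float>& c, const std::shared_ptr<Stream>& stream) {
    // Pretend "execution" on the caller-provided stream: C = A + B elementwise.
    for (size_t i = 0; i < c.size(); ++i) c[i] = a[i] + b[i];
    std::cout << "executed on stream " << stream->id << "\n";
  }
};

int main() {
  KernelContext ctx;
  MatMulPrimitive prim;  // typically memoized and shared across invocations
  std::vector<float> a{1, 2, 3}, b{4, 5, 6}, c(3);

  // Each invocation derives its own stream from the kernel context and passes
  // it to Execute(), instead of relying on a stream cached in the primitive.
  auto stream = MakeStreamForContext(ctx);
  prim.Execute(a, b, c, stream);
  return 0;
}
```

Passing the stream per call rather than caching it inside the shared, memoized primitive is what lets the same cached primitive run against whichever threadpool the current op invocation provides.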