diff --git a/tensorflow/core/kernels/mkl_matmul_op.cc b/tensorflow/core/kernels/mkl_matmul_op.cc
index 86193901c96..fb197618fb0 100644
--- a/tensorflow/core/kernels/mkl_matmul_op.cc
+++ b/tensorflow/core/kernels/mkl_matmul_op.cc
@@ -62,11 +62,11 @@ class MklMatMulOp : public OpKernel {
     dim_pair[0].first = transpose_a_ ? 0 : 1;
     dim_pair[0].second = transpose_b_ ? 1 : 0;
 
-    OP_REQUIRES(
-        ctx, a.dim_size(dim_pair[0].first) == b.dim_size(dim_pair[0].second),
-        errors::InvalidArgument(
-            "Matrix size-incompatible: In[0]: ", a.shape().DebugString(),
-            ", In[1]: ", b.shape().DebugString()));
+    OP_REQUIRES(ctx,
+                a.dim_size(dim_pair[0].first) == b.dim_size(dim_pair[0].second),
+                errors::InvalidArgument("Matrix size-incompatible: In[0]: ",
+                                        a.shape().DebugString(), ", In[1]: ",
+                                        b.shape().DebugString()));
     int a_dim_remaining = 1 - dim_pair[0].first;
     int b_dim_remaining = 1 - dim_pair[0].second;
     TensorShape out_shape(
@@ -158,9 +158,17 @@ class MklMatMulOp : public OpKernel {
 #ifdef ENABLE_MKLDNN_V1
     char char_transa = transa ? 'T' : 'N';
     char char_transb = transb ? 'T' : 'N';
-    VLOG(2) << "MKL DNN SGEMM CALLED";
+    VLOG(2) << "MKL DNN SGEMM called";
+#ifdef ENABLE_MKLDNN_THREADPOOL
+    auto eigen_tp =
+        MklDnnThreadPoolWrapper::GetInstance().CreateThreadPoolPtr(ctx);
+
+    dnnl_sgemm_tp(char_transa, char_transb, m, n, k, alpha, a, lda, b, ldb,
+                  beta, c, ldc, eigen_tp);
+#else
     dnnl_sgemm(char_transa, char_transb, m, n, k, alpha, a, lda, b, ldb, beta,
                c, ldc);
+#endif  // ENABLE_MKLDNN_THREADPOOL
 #else
     // TODO(intel-tf): Remove this after TF2.3 fork.
     cblas_sgemm(CblasRowMajor, transa ? CblasTrans : CblasNoTrans,
@@ -182,7 +190,7 @@ class MklMatMulOp : public OpKernel {
 #ifdef ENABLE_MKLDNN_V1
     const char ftrans[] = {'N', 'T', 'C'};
     dnnl_gemm<bfloat16>(ftrans[index_transa], ftrans[index_transb], m, n, k,
-                        alpha, a, lda, b, ldb, beta, c, ldc);
+                        alpha, a, lda, b, ldb, beta, c, ldc, ctx);
 #else
     Tensor c_float;
     OP_REQUIRES_OK(ctx, ctx->allocate_temp(DT_FLOAT, {m, n}, &c_float));
diff --git a/tensorflow/core/kernels/mkl_matmul_op_fused.cc b/tensorflow/core/kernels/mkl_matmul_op_fused.cc
index 99a2cfc214b..f3608ef72a8 100644
--- a/tensorflow/core/kernels/mkl_matmul_op_fused.cc
+++ b/tensorflow/core/kernels/mkl_matmul_op_fused.cc
@@ -86,11 +86,10 @@ class MklFusedMatMulOp : public MklDnnMatMulOpBase<T, T> {
     const int k = src_tf_shape.dim_size(dim_pair[0]);
     const int channel = weight_tf_shape.dim_size(1 - dim_pair[1]);
 
-    OP_REQUIRES(
-        ctx, k == weight_tf_shape.dim_size(dim_pair[1]),
-        errors::InvalidArgument(
-            "Matrix size-incompatible: In[0]: ", src_tf_shape.DebugString(),
-            ", In[1]: ", weight_tf_shape.DebugString()));
+    OP_REQUIRES(ctx, k == weight_tf_shape.dim_size(dim_pair[1]),
+                errors::InvalidArgument("Matrix size-incompatible: In[0]: ",
+                                        src_tf_shape.DebugString(), ", In[1]: ",
+                                        weight_tf_shape.DebugString()));
     OP_REQUIRES(ctx, bias_tensor.shape().dim_size(0) == channel,
                 errors::InvalidArgument(
                     "Must provide as many biases as the channel size: ",
@@ -159,8 +158,10 @@ class MklFusedMatMulOp : public MklDnnMatMulOpBase<T, T> {
 
       if (IS_SRC_REORDER_NEEDED(src_md, matmul_pd, matmul_prim)) {
         src_mkl.SetUsrMem(src_md, src_data);
-        src_mkl.CheckReorderToOpMem(MEMORY_PD_WITHOUT_DATA(
-            matmul_pd.get()->PRIMITIVE_DESC_SRC, this->cpu_engine_));
+        src_mkl.CheckReorderToOpMem(
+            MEMORY_PD_WITHOUT_DATA(matmul_pd.get()->PRIMITIVE_DESC_SRC,
+                                   this->cpu_engine_),
+            ctx);
         src_data = reinterpret_cast<T*>(src_mkl.GetOpMem().get_data_handle());
       }
 
@@ -191,19 +192,23 @@ class MklFusedMatMulOp : public MklDnnMatMulOpBase<T, T> {
           weight_data = cached_weight_data;
         } else {
           weight_mkl.SetUsrMem(weight_md, weight_data);
-          weight_mkl.CheckReorderToOpMem(MEMORY_PD_WITHOUT_DATA(
-              matmul_pd.get()->PRIMITIVE_DESC_WEIGHTS, this->cpu_engine_));
+          weight_mkl.CheckReorderToOpMem(
+              MEMORY_PD_WITHOUT_DATA(matmul_pd.get()->PRIMITIVE_DESC_WEIGHTS,
+                                     this->cpu_engine_),
+              ctx);
           weight_data =
               reinterpret_cast<T*>(weight_mkl.GetOpMem().get_data_handle());
         }
       }
-
+      std::shared_ptr<stream> cpu_stream;
+      cpu_stream.reset(CreateStream(ctx, matmul_prim->GetEngine()));
       // Execute fused matmul op.
-      matmul_prim->Execute(src_data, weight_data, bias_data, dst_data);
+      matmul_prim->Execute(src_data, weight_data, bias_data, dst_data,
+                           cpu_stream);
     } catch (mkldnn::error& e) {
-      string error_msg = "Status: " + std::to_string(e.status) +
-                         ", message: " + string(e.message) + ", in file " +
-                         string(__FILE__) + ":" + std::to_string(__LINE__);
+      string error_msg = "Status: " + std::to_string(e.status) + ", message: " +
+                         string(e.message) + ", in file " + string(__FILE__) +
+                         ":" + std::to_string(__LINE__);
       OP_REQUIRES_OK(
           ctx, errors::Aborted("Operation received an exception:", error_msg));
     }
diff --git a/tensorflow/core/kernels/mkl_matmul_ops_common.h b/tensorflow/core/kernels/mkl_matmul_ops_common.h
index d3a05a4a6d2..d7af614ad04 100644
--- a/tensorflow/core/kernels/mkl_matmul_ops_common.h
+++ b/tensorflow/core/kernels/mkl_matmul_ops_common.h
@@ -75,8 +75,7 @@ class MklDnnMatMulFwdPrimitive : public MklPrimitive {
  public:
   explicit MklDnnMatMulFwdPrimitive(
       const MklDnnMatMulFwdParams& matmulFwdParams)
-      : cpu_engine_(ENGINE_CPU, 0) {
-    context_.fwd_stream.reset(new CPU_STREAM(cpu_engine_));
+      : MklPrimitive(engine(ENGINE_CPU, 0)) {
     // Create matmul primitive
     if (context_.matmul_fwd == nullptr) {
       Setup(matmulFwdParams);
@@ -91,7 +90,18 @@ class MklDnnMatMulFwdPrimitive : public MklPrimitive {
   //  - bias_data: input data buffer of bias
   //  - dst_data: output data buffer of dst
   void Execute(const Tinput* src_data, const Tweight* weight_data,
-               const Tbias* bias_data, Toutput* dst_data) {
+               const Tbias* bias_data, Toutput* dst_data,
+               std::shared_ptr<stream> fwd_stream) {
+#ifdef ENABLE_MKLDNN_THREADPOOL
+    context_.src_mem->set_data_handle(
+        static_cast<void*>(const_cast<Tinput*>(src_data)), *fwd_stream);
+    context_.weight_mem->set_data_handle(
+        static_cast<void*>(const_cast<Tweight*>(weight_data)), *fwd_stream);
+    context_.bias_mem->set_data_handle(
+        static_cast<void*>(const_cast<Tbias*>(bias_data)));
+    context_.dst_mem->set_data_handle(static_cast<void*>(dst_data),
+                                      *fwd_stream);
+#else
     context_.src_mem->set_data_handle(
         static_cast<void*>(const_cast<Tinput*>(src_data)));
     context_.weight_mem->set_data_handle(
@@ -99,12 +109,12 @@ class MklDnnMatMulFwdPrimitive : public MklPrimitive {
     context_.bias_mem->set_data_handle(
         static_cast<void*>(const_cast<Tbias*>(bias_data)));
     context_.dst_mem->set_data_handle(static_cast<void*>(dst_data));
+#endif  // ENABLE_MKLDNN_THREADPOOL
 
 #ifdef ENABLE_MKLDNN_V1
-    execute_primitives(context_.fwd_primitives, context_.fwd_stream,
-                       context_.net_args);
+    execute_primitives(context_.fwd_primitives, fwd_stream, context_.net_args);
 #else
-    context_.fwd_stream->submit(context_.fwd_primitives);
+    fwd_stream->submit(context_.fwd_primitives);
 #endif  // ENABLE_MKLDNN_V1
 
     // After execution, set data handle back
@@ -153,7 +163,6 @@ class MklDnnMatMulFwdPrimitive : public MklPrimitive {
 
     // Inner-product primitive.
     std::shared_ptr<mkldnn::primitive> matmul_fwd;
-    std::shared_ptr<mkldnn::stream> fwd_stream;
     std::vector<mkldnn::primitive> fwd_primitives;
 
 #ifdef ENABLE_MKLDNN_V1
@@ -176,8 +185,7 @@ class MklDnnMatMulFwdPrimitive : public MklPrimitive {
           weight_md(nullptr),
           bias_md(nullptr),
           dst_md(nullptr),
-          matmul_fwd(nullptr),
-          fwd_stream(nullptr) {
+          matmul_fwd(nullptr) {
     }
   };
 
@@ -292,7 +300,6 @@ class MklDnnMatMulFwdPrimitive : public MklPrimitive {
   }
 
   struct MklDnnMatMulFwdContext context_;
-  engine cpu_engine_;
 };
 
 template <typename T, typename Tinput, typename Tweight, typename Tbias,
@@ -439,8 +446,10 @@ class MklDnnMatMulOpBase : public OpKernel {
 
     // reorder and cache the weight
     weight.SetUsrMem(weight_md, &weight_tensor);
-    weight.CheckReorderToOpMem(MEMORY_PD_WITHOUT_DATA(
-        matmul_fwd_pd.get()->PRIMITIVE_DESC_WEIGHTS, cpu_engine_));
+    weight.CheckReorderToOpMem(
+        MEMORY_PD_WITHOUT_DATA(matmul_fwd_pd.get()->PRIMITIVE_DESC_WEIGHTS,
+                               cpu_engine_),
+        context);
     weight_data = static_cast<Tweight*>(weight.GetOpMem().get_data_handle());
 
     Tensor* weight_tensor_ptr = nullptr;
@@ -544,21 +553,28 @@ template <typename T>
 class MklMatMulPrimitive : public MklPrimitive {
  public:
   explicit MklMatMulPrimitive(const MklMatMulParams& params)
-      : cpu_engine_(ENGINE_CPU, 0) {
-    context_.stream.reset(new CPU_STREAM(cpu_engine_));
+      : MklPrimitive(engine(ENGINE_CPU, 0)) {
     // Create matmul primitive
     Setup(params);
   }
 
   ~MklMatMulPrimitive() {}
 
-  void Execute(const T* a_data, const T* b_data, T* c_data) {
+  void Execute(const T* a_data, const T* b_data, T* c_data,
+               std::shared_ptr<stream> stream) {
+#ifdef ENABLE_MKLDNN_THREADPOOL
+    context_.a_mem->set_data_handle(static_cast<void*>(const_cast<T*>(a_data)),
+                                    *stream);
+    context_.b_mem->set_data_handle(static_cast<void*>(const_cast<T*>(b_data)),
+                                    *stream);
+    context_.c_mem->set_data_handle(static_cast<void*>(const_cast<T*>(c_data)),
+                                    *stream);
+#else
     context_.a_mem->set_data_handle(static_cast<void*>(const_cast<T*>(a_data)));
     context_.b_mem->set_data_handle(static_cast<void*>(const_cast<T*>(b_data)));
     context_.c_mem->set_data_handle(static_cast<void*>(const_cast<T*>(c_data)));
-
-    execute_primitives(context_.matmul_primitives, context_.stream,
-                       context_.net_args);
+#endif  // ENABLE_MKLDNN_THREADPOOL
+    execute_primitives(context_.matmul_primitives, stream, context_.net_args);
 
     // After execution, set data handle back
     context_.a_mem->set_data_handle(DummyData);
@@ -584,7 +600,6 @@ class MklMatMulPrimitive : public MklPrimitive {
     std::shared_ptr<mkldnn::memory::desc> c_md;
 
     // MatMul primitive.
-    std::shared_ptr<mkldnn::stream> stream;
     std::vector<mkldnn::primitive> matmul_primitives;
     std::vector<std::unordered_map<int, memory>> net_args;
 
@@ -596,8 +611,7 @@ class MklMatMulPrimitive : public MklPrimitive {
           prim_desc(nullptr),
           a_md(nullptr),
           b_md(nullptr),
-          c_md(nullptr),
-          stream(nullptr) {}
+          c_md(nullptr) {}
   };
 
   void Setup(const MklMatMulParams& params) {
@@ -639,7 +653,6 @@ class MklMatMulPrimitive : public MklPrimitive {
   }
 
   struct MklMatMulContext context_;
-  engine cpu_engine_;
 };
 
 template <typename T>
@@ -707,8 +720,8 @@ void dnnl_gemm_batch(const std::vector<bool>& transa,
                      const std::vector<int>& n, const std::vector<int>& k,
                      const std::vector<float>& alpha, const T* a, const T* b,
                      const std::vector<float>& beta, T* c,
-                     const int group_count,
-                     const std::vector<int>& group_size) {
+                     const int group_count, const std::vector<int>& group_size,
+                     OpKernelContext* ctx = nullptr) {
   // Current BatchMatMul support in Tensorflow is narrower than the one offered
   // by MKL and MKL-DNN. Current BatchMatMul support in Tensorflow uses only 1
   // group of size equal to batch_size, and all MatMul parameters (m, n, k,
@@ -757,13 +770,15 @@ void dnnl_gemm_batch(const std::vector<bool>& transa,
       MklMatMulPrimitiveFactory<T>::Get(params, 0);
 
   // Execute matmul primitive.
-  matmul_prim->Execute(a, b, c);
+  std::shared_ptr<stream> cpu_stream;
+  cpu_stream.reset(CreateStream(ctx, matmul_prim->GetEngine()));
+  matmul_prim->Execute(a, b, c, cpu_stream);
 }
 
 template <typename T>
 void dnnl_gemm(char transa, char transb, int64_t m, int64_t n, int64_t k,
                float alpha, const T* a, int64_t lda, const T* b, int64_t ldb,
-               float beta, T* c, int64_t ldc) {
+               float beta, T* c, int64_t ldc, OpKernelContext* ctx = nullptr) {
   using dims = mkldnn::memory::dims;
 
   // Prepare strides based on the transa and transb flags: transposed
@@ -786,7 +801,9 @@ void dnnl_gemm(char transa, char transb, int64_t m, int64_t n, int64_t k,
       MklMatMulPrimitiveFactory<T>::Get(params, 0);
 
   // Execute matmul primitive.
-  matmul_prim->Execute(a, b, c);
+  std::shared_ptr<stream> cpu_stream;
+  cpu_stream.reset(CreateStream(ctx, matmul_prim->GetEngine()));
+  matmul_prim->Execute(a, b, c, cpu_stream);
 }
 
 }  // anonymous namespace
diff --git a/tensorflow/core/kernels/mkl_qmatmul_op.cc b/tensorflow/core/kernels/mkl_qmatmul_op.cc
index cc7127e0559..e73f30db4da 100644
--- a/tensorflow/core/kernels/mkl_qmatmul_op.cc
+++ b/tensorflow/core/kernels/mkl_qmatmul_op.cc
@@ -245,8 +245,10 @@ class MklDnnQuantizedMatMulOp : public MklDnnMatMulOpBase<Tweight, Toutput> {
       Tinput* src_data = nullptr;
       if (IS_SRC_REORDER_NEEDED(src_md, matmul_fwd_pd, matmul_fwd)) {
         src.SetUsrMem(src_md, &src_tensor);
-        src.CheckReorderToOpMem(MEMORY_PD_WITHOUT_DATA(
-            matmul_fwd_pd.get()->PRIMITIVE_DESC_SRC, this->cpu_engine_));
+        src.CheckReorderToOpMem(
+            MEMORY_PD_WITHOUT_DATA(matmul_fwd_pd.get()->PRIMITIVE_DESC_SRC,
+                                   this->cpu_engine_),
+            context);
         src_data = static_cast<Tinput*>(src.GetOpMem().get_data_handle());
       } else {
         src_data = static_cast<Tinput*>(
@@ -279,8 +281,11 @@ class MklDnnQuantizedMatMulOp : public MklDnnMatMulOpBase<Tweight, Toutput> {
 
         if (!is_weight_cached) {
           weight.SetUsrMem(weight_md, &weight_tensor);
-          weight.CheckReorderToOpMem(MEMORY_PD_WITHOUT_DATA(
-              matmul_fwd_pd.get()->PRIMITIVE_DESC_WEIGHTS, this->cpu_engine_));
+          weight.CheckReorderToOpMem(
+              MEMORY_PD_WITHOUT_DATA(
+                  matmul_fwd_pd.get()->PRIMITIVE_DESC_WEIGHTS,
+                  this->cpu_engine_),
+              context);
           weight_data =
               static_cast<Tweight*>(weight.GetOpMem().get_data_handle());
         }
@@ -290,10 +295,13 @@ class MklDnnQuantizedMatMulOp : public MklDnnMatMulOpBase<Tweight, Toutput> {
             const_cast<Tweight*>(weight_tensor.flat<Tweight>().data()));
       }
 
+      std::shared_ptr<stream> cpu_stream;
+      cpu_stream.reset(CreateStream(context, matmul_fwd->GetEngine()));
       // Execute inner-product
-      Tbias* bias_data = this->GetBiasHandle(context, matmul_fwd_pd,
-                                             bias_tensor, weight_tensor);
-      matmul_fwd->Execute(src_data, weight_data, bias_data, dst_data);
+      Tbias* bias_data = this->GetBiasHandle(
+          context, matmul_fwd_pd, bias_tensor, weight_tensor, cpu_stream);
+      matmul_fwd->Execute(src_data, weight_data, bias_data, dst_data,
+                          cpu_stream);
     } catch (mkldnn::error& e) {
       string error_msg = tensorflow::strings::StrCat(
           "Status: ", e.status, ", message: ", string(e.message), ", in file ",
@@ -393,7 +401,8 @@ class MklDnnQuantizedMatMulOp : public MklDnnMatMulOpBase<Tweight, Toutput> {
       OpKernelContext* context,
       std::shared_ptr<mkldnn::inner_product_forward::primitive_desc>&
           mkldnn_matmul_fwd_pd,
-      const Tensor& bias_tensor, const Tensor& weight_tensor) {
+      const Tensor& bias_tensor, const Tensor& weight_tensor,
+      std::shared_ptr<stream> reorder_stream) {
     // If the bias is qint32, it means the bias is already converted offline.
     // and it can be added to matmul output directly.
     if (std::is_same<Tbias, qint32>::value) {
@@ -449,7 +458,6 @@ class MklDnnQuantizedMatMulOp : public MklDnnMatMulOpBase<Tweight, Toutput> {
         std::vector<float> scales;
         scales.push_back(out_scale);
         mkldnn::primitive_attr bias_attr;
-        stream reorder_stream = CPU_STREAM(this->cpu_engine_);
         bias_attr.set_output_scales(0, scales);
 
         void* bias_buf = static_cast<void*>(
@@ -468,14 +476,14 @@ class MklDnnQuantizedMatMulOp : public MklDnnMatMulOpBase<Tweight, Toutput> {
             {MKLDNN_ARG_FROM, *input_bias_},
             { MKLDNN_ARG_TO,
               *scaled_bias_ }};
-        net.at(0).execute(reorder_stream, reorder_net_args);
+        net.at(0).execute(*reorder_stream, reorder_net_args);
 #else
         auto reorder_desc = mkldnn::reorder::primitive_desc(
             input_bias_->get_primitive_desc(),
             scaled_bias_->get_primitive_desc(), bias_attr);
         net.push_back(
             mkldnn::reorder(reorder_desc, *input_bias_, *scaled_bias_));
-        reorder_stream.submit(net).wait();
+        reorder_stream->submit(net).wait();
 #endif  // ENABLE_MKLDNN_V1
 
         return reinterpret_cast<Tbias*>(scaled_bias_->get_data_handle());