diff --git a/tensorflow/core/common_runtime/mkl_threadpool_device_test.cc b/tensorflow/core/common_runtime/mkl_threadpool_device_test.cc
index 5d583a8360b..c29752d3c2c 100644
--- a/tensorflow/core/common_runtime/mkl_threadpool_device_test.cc
+++ b/tensorflow/core/common_runtime/mkl_threadpool_device_test.cc
@@ -25,7 +25,7 @@ limitations under the License.
 
 namespace tensorflow {
 
-#ifdef _OPENMP
+#if defined(_OPENMP) && !defined(ENABLE_MKLDNN_THREADPOOL)
 TEST(MKLThreadPoolDeviceTest, TestOmpDefaults) {
   SessionOptions options;
   unsetenv("OMP_NUM_THREADS");
@@ -46,7 +46,7 @@ TEST(MKLThreadPoolDeviceTest, TestOmpPreSets) {
 
   EXPECT_EQ(omp_get_max_threads(), 314);
 }
-#endif  // _OPENMP
+#endif  // defined(_OPENMP) && !defined(ENABLE_MKLDNN_THREADPOOL)
 
 }  // namespace tensorflow
 
diff --git a/tensorflow/core/kernels/mkl_conv_ops_test.cc b/tensorflow/core/kernels/mkl_conv_ops_test.cc
index a055351337c..9d11b0fb006 100644
--- a/tensorflow/core/kernels/mkl_conv_ops_test.cc
+++ b/tensorflow/core/kernels/mkl_conv_ops_test.cc
@@ -28,7 +28,7 @@ limitations under the License.
 #include "tensorflow/core/public/session.h"
 
 #if defined(INTEL_MKL_DNN_ONLY)
-#include "third_party/intel_mkl_dnn/include/mkldnn.h"
+#include "mkldnn.hpp"
 #include "tensorflow/core/util/mkl_util.h"
 #endif
 
diff --git a/tensorflow/core/kernels/mkl_qmatmul_op.cc b/tensorflow/core/kernels/mkl_qmatmul_op.cc
index cc7127e0559..d8bbc130c55 100644
--- a/tensorflow/core/kernels/mkl_qmatmul_op.cc
+++ b/tensorflow/core/kernels/mkl_qmatmul_op.cc
@@ -91,12 +91,15 @@ limitations under the License.
 // https://software.intel.com/en-us/articles/lower-numerical-precision-deep-learning-inference-and-training
 #ifdef INTEL_MKL
 
+#include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/register_types.h"
 #include "tensorflow/core/kernels/fill_functor.h"
 #include "tensorflow/core/kernels/mkl_matmul_ops_common.h"
 #include "tensorflow/core/kernels/mkl_quantized_conv_ops.h"
 #include "tensorflow/core/kernels/no_op.h"
 #include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/util/mkl_threadpool.h"
+#include "tensorflow/core/util/work_sharder.h"
 
 namespace {
 enum {
@@ -428,6 +431,26 @@ class MklDnnQuantizedMatMulOp : public MklDnnMatMulOpBase<Tweight, Toutput> {
                     ((max_input - min_input) *
                      std::max(std::abs(max_weight), std::abs(min_weight)));
 
+#ifdef ENABLE_MKLDNN_THREADPOOL
+        auto parallel_func = [&](int64 start, int64 end) {
+          for (int64 j = start; j < end; ++j) {
+            int x = 0;
+            for (int64 i = 0; i < k; ++i) {
+              x += wt_buf[i * n + j];
+            }
+            comp_bias[j] =
+                ((bias_buf[j] * out_scale) + static_cast<float>(x * qa_amin));
+          }
+        };
+
+        const float kArithCost = 2.5f;
+        const float kMovCost = 1.0f;
+        const float shard_cost = 4 * kArithCost + kMovCost;
+        const DeviceBase::CpuWorkerThreads& worker_threads =
+            *(context->device()->tensorflow_cpu_worker_threads());
+        Shard(worker_threads.num_threads, worker_threads.workers, n, shard_cost,
+              parallel_func);
+#else
 #pragma omp parallel for schedule(static)
         for (int j = 0; j < n; ++j) {
           int x = 0;
@@ -437,7 +460,7 @@ class MklDnnQuantizedMatMulOp : public MklDnnMatMulOpBase<Tweight, Toutput> {
           comp_bias[j] =
               ((bias_buf[j] * out_scale) + static_cast<float>(x * qa_amin));
         }
-
+#endif  // ENABLE_MKLDNN_THREADPOOL
         return reinterpret_cast<Tbias*>(comp_bias_);
 
       } else if (mode_ == QUANTIZE_MODE_SCALED) {
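Note on the Shard() call introduced in mkl_qmatmul_op.cc above: under ENABLE_MKLDNN_THREADPOOL the bias-compensation loop is divided over the intra-op thread pool instead of relying on the OpenMP pragma in the #else branch. A minimal standalone sketch of the partitioning semantics follows; it is illustrative only and is not TensorFlow's actual work_sharder implementation (ShardSketch and its fixed block size are assumptions, and it ignores the cost-based block sizing that shard_cost feeds into).

    #include <algorithm>
    #include <cstdint>
    #include <functional>
    #include <thread>
    #include <vector>

    // Splits [0, total) into contiguous blocks and runs work(start, end) on
    // each block from its own thread. Assumes num_threads >= 1.
    void ShardSketch(int num_threads, int64_t total,
                     const std::function<void(int64_t, int64_t)>& work) {
      const int64_t block = (total + num_threads - 1) / num_threads;
      std::vector<std::thread> threads;
      for (int64_t start = 0; start < total; start += block) {
        const int64_t end = std::min(start + block, total);
        threads.emplace_back(work, start, end);
      }
      for (auto& t : threads) t.join();
    }

With that, ShardSketch(worker_threads.num_threads, n, parallel_func) mimics what the Shard(...) call in the hunk above requests, modulo the runtime's cost-driven choice of shard size.
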
diff --git a/tensorflow/core/kernels/mkl_quantized_conv_ops.h b/tensorflow/core/kernels/mkl_quantized_conv_ops.h
index fef2d837cf2..037a3a5f3ff 100644
--- a/tensorflow/core/kernels/mkl_quantized_conv_ops.h
+++ b/tensorflow/core/kernels/mkl_quantized_conv_ops.h
@@ -69,6 +69,19 @@ void MklQuantizationRangeForMultiplication(float min_a, float max_a,
   const float* max_b = max_b_vector.flat<float>().data();
   float* min_c = (*min_c_vector)->flat<float>().data();
   float* max_c = (*max_c_vector)->flat<float>().data();
+#ifdef ENABLE_MKLDNN_THREADPOOL
+  // TODO: Add Eigen parallel_for.
+  for (size_t n = 0; n < n_channel; ++n) {
+    float a_float_for_one_quant_level =
+        MklFloatForOneQuantizedLevel<T1>(min_a, max_a);
+    float b_float_for_one_quant_level =
+        MklFloatForOneQuantizedLevel<T2>(min_b[n], max_b[n]);
+    float c_float_for_one_quant_level =
+        a_float_for_one_quant_level * b_float_for_one_quant_level;
+    min_c[n] = c_float_for_one_quant_level * c_lowest;
+    max_c[n] = c_float_for_one_quant_level * c_highest;
+  }
+#else
 #pragma omp parallel for
   for (size_t n = 0; n < n_channel; ++n) {
     float a_float_for_one_quant_level =
@@ -80,6 +93,7 @@ void MklQuantizationRangeForMultiplication(float min_a, float max_a,
     min_c[n] = c_float_for_one_quant_level * c_lowest;
     max_c[n] = c_float_for_one_quant_level * c_highest;
   }
+#endif  // ENABLE_MKLDNN_THREADPOOL
 }
 
 }  // namespace tensorflow
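The serial per-channel loop added above carries a TODO to move to an Eigen parallel_for. A hedged sketch of how that could look, assuming the function were given access to an Eigen::ThreadPoolDevice (for example from OpKernelContext::eigen_cpu_device()); the wrapper name ComputeRangesParallel and the cost constants are illustrative, not part of this patch.

    #include <cstdint>
    #include <functional>

    #include "unsupported/Eigen/CXX11/Tensor"

    // Runs per_channel(n) for n in [0, n_channel) on the Eigen thread pool.
    void ComputeRangesParallel(const Eigen::ThreadPoolDevice& device,
                               int64_t n_channel,
                               const std::function<void(int64_t)>& per_channel) {
      // Rough per-element cost estimate; Eigen uses it to pick a block size.
      const Eigen::TensorOpCost cost(/*bytes_loaded=*/16, /*bytes_stored=*/8,
                                     /*compute_cycles=*/8);
      device.parallelFor(n_channel, cost,
                         [&](Eigen::Index start, Eigen::Index end) {
                           for (Eigen::Index n = start; n < end; ++n) {
                             per_channel(n);
                           }
                         });
    }
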
diff --git a/tensorflow/core/kernels/mkl_requantization_range_per_channel_op.cc b/tensorflow/core/kernels/mkl_requantization_range_per_channel_op.cc
index 767a6f1c397..0a19573d901 100644
--- a/tensorflow/core/kernels/mkl_requantization_range_per_channel_op.cc
+++ b/tensorflow/core/kernels/mkl_requantization_range_per_channel_op.cc
@@ -28,6 +28,7 @@ limitations under the License.
 #include "tensorflow/core/kernels/meta_support.h"
 #include "tensorflow/core/kernels/no_op.h"
 #include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/util/mkl_threadpool.h"
 #include "tensorflow/core/util/mkl_util.h"
 
 namespace tensorflow {
@@ -73,6 +74,26 @@ class MklRequantizationRangePerChannelOp : public OpKernel {
 
     // Find the ranges of each channel in parallel.
     float out_min_max = std::numeric_limits<float>::min();
+#ifdef ENABLE_MKLDNN_THREADPOOL
+    // TODO: Add Eigen parallel_for.
+    for (size_t i = 0; i < depth; ++i) {
+      Eigen::Tensor<qint32, 0, Eigen::RowMajor> min =
+          transposed_input.chip<0>(i).minimum();
+      Eigen::Tensor<qint32, 0, Eigen::RowMajor> max =
+          transposed_input.chip<0>(i).maximum();
+      const int32_t min_per_channel = min();
+      const int32_t max_per_channel = max();
+      const int32_t abs_max =
+          std::max(std::abs(min_per_channel), std::abs(max_per_channel));
+      float scale =
+          std::max(std::abs(input_min_data[i]), std::abs(input_max_data[i]));
+      ranges[i] =
+          scale * static_cast<float>(abs_max) / static_cast<float>(1L << 31);
+      if (min_per_channel < 0) is_non_negative = false;
+
+      out_min_max = std::max(out_min_max, ranges[i]);
+    }
+#else
 #pragma omp parallel for reduction(max : out_min_max)
     for (size_t i = 0; i < depth; ++i) {
       Eigen::Tensor<qint32, 0, Eigen::RowMajor> min =
@@ -92,6 +113,7 @@ class MklRequantizationRangePerChannelOp : public OpKernel {
       // Thread-local out_min_max.
       out_min_max = std::max(out_min_max, ranges[i]);
     }
+#endif  // ENABLE_MKLDNN_THREADPOOL
     // All local out_min_max gets max-reduced into one global out_min_max at
     // the end of the loop by specifying reduction(max:out_min_max) along with
     // omp parallel for.
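Because the ENABLE_MKLDNN_THREADPOOL branch above is serial for now, the max-reduction over ranges[] is race-free as written. If the loop is later parallelized (per the TODO), the semantics of the dropped reduction(max : out_min_max) clause could be reproduced with per-shard partial maxima. A hedged, standalone sketch; MaxReduceSketch and the shard layout are illustrative only.

    #include <algorithm>
    #include <cstddef>
    #include <limits>
    #include <vector>

    // Computes max(ranges) with one partial maximum per shard, so each shard
    // could run on its own thread without racing on the global maximum.
    float MaxReduceSketch(const std::vector<float>& ranges, int num_shards) {
      std::vector<float> shard_max(num_shards,
                                   std::numeric_limits<float>::lowest());
      const size_t block = (ranges.size() + num_shards - 1) / num_shards;
      for (int s = 0; s < num_shards; ++s) {  // each iteration == one shard
        const size_t begin = s * block;
        const size_t end = std::min(ranges.size(), begin + block);
        for (size_t i = begin; i < end; ++i) {
          shard_max[s] = std::max(shard_max[s], ranges[i]);
        }
      }
      return *std::max_element(shard_max.begin(), shard_max.end());
    }
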
diff --git a/tensorflow/tensorflow.bzl b/tensorflow/tensorflow.bzl
index 9a780839be3..5dc5877367b 100644
--- a/tensorflow/tensorflow.bzl
+++ b/tensorflow/tensorflow.bzl
@@ -354,9 +354,7 @@ def tf_copts(
     )
 
 def tf_openmp_copts():
-    # TODO(intel-mkl): Remove -fopenmp for threadpool after removing all
-    # omp pragmas in tensorflow/core.
-    return if_mkl_lnx_x64(["-fopenmp"]) + if_mkldnn_threadpool(["-fopenmp"])
+    return (if_mkl_lnx_x64(["-fopenmp"]) + if_mkldnn_threadpool(["-fno-openmp"]))
 
 def tfe_xla_copts():
     return select({