From c63d21b0bfc534b6377b332e9d2ba2abbdb7e0eb Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Tue, 25 Sep 2018 22:57:54 -0700
Subject: [PATCH] Adds a build flag to enable MKL (enable_mkl=true).

PiperOrigin-RevId: 214557082
---
 tensorflow/contrib/cmake/CMakeLists.txt       |  2 +-
 ...direct_session_with_tracking_alloc_test.cc |  8 ++--
 .../common_runtime/mkl_cpu_allocator_test.cc  |  4 +-
 .../core/common_runtime/threadpool_device.cc  |  5 ++-
 tensorflow/core/graph/mkl_layout_pass.cc      |  4 ++
 tensorflow/core/graph/mkl_layout_pass_test.cc |  4 +-
 .../core/graph/mkl_tfconversion_pass.cc       |  2 +
 .../core/graph/mkl_tfconversion_pass_test.cc  |  4 +-
 .../core/kernels/batch_matmul_op_complex.cc   | 10 +++--
 .../core/kernels/batch_matmul_op_real.cc      |  9 +++-
 tensorflow/core/kernels/cwise_ops_common.cc   |  4 +-
 .../core/kernels/gather_nd_op_cpu_impl.h      |  6 +--
 tensorflow/core/kernels/matmul_op.cc          |  8 ++--
 .../core/kernels/mkl_batch_matmul_op.cc       |  2 +
 tensorflow/core/kernels/mkl_matmul_op.cc      |  6 ++-
 tensorflow/core/kernels/slice_op.cc           | 26 +++++-------
 tensorflow/core/kernels/transpose_op.cc       | 10 ++---
 tensorflow/core/util/port.cc                  |  4 +-
 tensorflow/tensorflow.bzl                     |  3 ++
 third_party/mkl/BUILD                         | 23 +++++++----
 third_party/mkl/build_defs.bzl                | 41 ++++++++++++++-----
 third_party/mkl_dnn/BUILD                     |  6 +--
 third_party/mkl_dnn/build_defs.bzl            |  2 +-
 tools/bazel.rc                                |  5 ++-
 24 files changed, 123 insertions(+), 75 deletions(-)

diff --git a/tensorflow/contrib/cmake/CMakeLists.txt b/tensorflow/contrib/cmake/CMakeLists.txt
index ebcabb42230..c6d6f04168b 100644
--- a/tensorflow/contrib/cmake/CMakeLists.txt
+++ b/tensorflow/contrib/cmake/CMakeLists.txt
@@ -353,7 +353,7 @@ endif()
 
 # MKL Support
 if (tensorflow_ENABLE_MKL_SUPPORT)
-  add_definitions(-DINTEL_MKL -DEIGEN_USE_VML)
+  add_definitions(-DINTEL_MKL -DEIGEN_USE_VML -DENABLE_MKL)
   include(mkl)
   list(APPEND tensorflow_EXTERNAL_LIBRARIES ${mkl_STATIC_LIBRARIES})
   list(APPEND tensorflow_EXTERNAL_DEPENDENCIES mkl_copy_shared_to_destination)
diff --git a/tensorflow/core/common_runtime/direct_session_with_tracking_alloc_test.cc b/tensorflow/core/common_runtime/direct_session_with_tracking_alloc_test.cc
index 2ed4f69f90b..efd6185f8b7 100644
--- a/tensorflow/core/common_runtime/direct_session_with_tracking_alloc_test.cc
+++ b/tensorflow/core/common_runtime/direct_session_with_tracking_alloc_test.cc
@@ -108,7 +108,7 @@ TEST(DirectSessionWithTrackingAllocTest, CostModelTest) {
         EXPECT_EQ(2, shape.dim(0).size());
         EXPECT_EQ(1, shape.dim(1).size());
         if (node->name() == y->name()) {
-#ifdef INTEL_MKL
+#if defined(INTEL_MKL) && defined(ENABLE_MKL)
           // if MKL is used, it goes through various additional
           // graph rewrite pass. In TF, everytime a graph pass
           // happens, "constant" nodes are allocated
@@ -120,13 +120,13 @@ TEST(DirectSessionWithTrackingAllocTest, CostModelTest) {
           EXPECT_EQ(29, cm->AllocationId(node, 0));
 #else
           EXPECT_EQ(21, cm->AllocationId(node, 0));
-#endif
+#endif  // INTEL_MKL && ENABLE_MKL
         } else {
-#ifdef INTEL_MKL
+#if defined(INTEL_MKL) && defined(ENABLE_MKL)
           EXPECT_EQ(30, cm->AllocationId(node, 0));
 #else
           EXPECT_EQ(22, cm->AllocationId(node, 0));
-#endif
+#endif  // INTEL_MKL && ENABLE_MKL
         }
       }
       EXPECT_LE(0, cm->MaxExecutionTime(node));
diff --git a/tensorflow/core/common_runtime/mkl_cpu_allocator_test.cc b/tensorflow/core/common_runtime/mkl_cpu_allocator_test.cc
index a67411cd2e2..e08ab576385 100644
--- a/tensorflow/core/common_runtime/mkl_cpu_allocator_test.cc
+++ b/tensorflow/core/common_runtime/mkl_cpu_allocator_test.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#ifdef INTEL_MKL
+#if defined(INTEL_MKL) && defined(ENABLE_MKL)
 
 #include "tensorflow/core/common_runtime/mkl_cpu_allocator.h"
 
@@ -50,4 +50,4 @@ TEST(MKLBFCAllocatorTest, TestMaxLimit) {
 
 }  // namespace tensorflow
 
-#endif  // INTEL_MKL
+#endif  // INTEL_MKL && ENABLE_MKL
diff --git a/tensorflow/core/common_runtime/threadpool_device.cc b/tensorflow/core/common_runtime/threadpool_device.cc
index 0fbc20b34ba..8587d1783ac 100644
--- a/tensorflow/core/common_runtime/threadpool_device.cc
+++ b/tensorflow/core/common_runtime/threadpool_device.cc
@@ -113,8 +113,11 @@ class MklCPUAllocatorFactory : public AllocatorFactory {
   }
 };
 
+#ifdef ENABLE_MKL
 REGISTER_MEM_ALLOCATOR("MklCPUAllocator", 200, MklCPUAllocatorFactory);
+#endif  // ENABLE_MKL
+
 }  // namespace
-#endif
+#endif  // INTEL_MKL
 
 }  // namespace tensorflow
diff --git a/tensorflow/core/graph/mkl_layout_pass.cc b/tensorflow/core/graph/mkl_layout_pass.cc
index f5b01058628..37b88f17282 100644
--- a/tensorflow/core/graph/mkl_layout_pass.cc
+++ b/tensorflow/core/graph/mkl_layout_pass.cc
@@ -977,7 +977,9 @@ std::vector<MklLayoutRewritePass::ContextInfo*> MklLayoutRewritePass::cinfo_;
 // nodes. Do not change the ordering of the Mkl passes.
 const OptimizationPassRegistry::Grouping kMklLayoutRewritePassGroup =
     OptimizationPassRegistry::POST_PARTITIONING;
+#ifdef ENABLE_MKL
 REGISTER_OPTIMIZATION(kMklLayoutRewritePassGroup, 1, MklLayoutRewritePass);
+#endif  // ENABLE_MKL
 
 //////////////////////////////////////////////////////////////////////////
 //           Helper functions for creating new node
@@ -3150,7 +3152,9 @@ MklLayoutRewritePass::ConstStringsInfo MklLayoutRewritePass::csinfo_;
 // nodes. Do not change the ordering of the Mkl passes.
 const OptimizationPassRegistry::Grouping kMklLayoutRewritePassGroup =
     OptimizationPassRegistry::POST_PARTITIONING;
+#ifdef ENABLE_MKL
 REGISTER_OPTIMIZATION(kMklLayoutRewritePassGroup, 1, MklLayoutRewritePass);
+#endif  // ENABLE_MKL
 
 //////////////////////////////////////////////////////////////////////////
 //           Helper functions for creating new node
diff --git a/tensorflow/core/graph/mkl_layout_pass_test.cc b/tensorflow/core/graph/mkl_layout_pass_test.cc
index e8bac847e58..f42a4ee98bf 100644
--- a/tensorflow/core/graph/mkl_layout_pass_test.cc
+++ b/tensorflow/core/graph/mkl_layout_pass_test.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#ifdef INTEL_MKL
+#if defined(INTEL_MKL) && defined(ENABLE_MKL)
 
 #include "tensorflow/core/graph/mkl_layout_pass.h"
 #include "tensorflow/core/graph/mkl_graph_util.h"
@@ -3586,4 +3586,4 @@ BENCHMARK(BM_MklLayoutRewritePass)->Arg(1000)->Arg(10000);
 
 }  // namespace tensorflow
 
-#endif /* INTEL_MKL */
+#endif  // INTEL_MKL && ENABLE_MKL
diff --git a/tensorflow/core/graph/mkl_tfconversion_pass.cc b/tensorflow/core/graph/mkl_tfconversion_pass.cc
index b67a321fc1b..8c5ffd71a32 100644
--- a/tensorflow/core/graph/mkl_tfconversion_pass.cc
+++ b/tensorflow/core/graph/mkl_tfconversion_pass.cc
@@ -133,7 +133,9 @@ class MklToTfConversionPass : public GraphOptimizationPass {
 // complete picture of inputs and outputs of the nodes in the graphs.
 const OptimizationPassRegistry::Grouping kMklTfConvPassGroup =
     OptimizationPassRegistry::POST_PARTITIONING;
+#ifdef ENABLE_MKL
 REGISTER_OPTIMIZATION(kMklTfConvPassGroup, 2, MklToTfConversionPass);
+#endif  // ENABLE_MKL
 
 Status MklToTfConversionPass::InsertConversionNodeOnEdge(
     std::unique_ptr<Graph>* g, Edge* e) {
diff --git a/tensorflow/core/graph/mkl_tfconversion_pass_test.cc b/tensorflow/core/graph/mkl_tfconversion_pass_test.cc
index ebcb6de551e..319437a8016 100644
--- a/tensorflow/core/graph/mkl_tfconversion_pass_test.cc
+++ b/tensorflow/core/graph/mkl_tfconversion_pass_test.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#ifdef INTEL_MKL
+#if defined(INTEL_MKL) && defined(ENABLE_MKL)
 
 #include "tensorflow/core/graph/mkl_tfconversion_pass.h"
 #include "tensorflow/core/graph/mkl_graph_util.h"
@@ -304,4 +304,4 @@ BENCHMARK(BM_RunMklToTfConversionPass)->Arg(1000)->Arg(10000);
 }  // namespace
 }  // namespace tensorflow
 
-#endif /* INTEL_MKL */
+#endif  // INTEL_MKL && ENABLE_MKL
diff --git a/tensorflow/core/kernels/batch_matmul_op_complex.cc b/tensorflow/core/kernels/batch_matmul_op_complex.cc
index 54c45bfe639..f48bd0c3187 100644
--- a/tensorflow/core/kernels/batch_matmul_op_complex.cc
+++ b/tensorflow/core/kernels/batch_matmul_op_complex.cc
@@ -17,14 +17,18 @@ limitations under the License.
 
 namespace tensorflow {
 
-#if !defined(INTEL_MKL) || defined(INTEL_MKL_DNN_ONLY)
+// MKL_ML registers its own complex64/128 kernels in mkl_batch_matmul_op.cc
+// if defined(INTEL_MKL) && !defined(INTEL_MKL_DNN_ONLY) && defined(ENABLE_MKL).
+// Anything else (the complement) should register the TF ones.
+// (MKL-DNN doesn't implement these kernels either.)
+#if !defined(INTEL_MKL) || defined(INTEL_MKL_DNN_ONLY) || !defined(ENABLE_MKL)
 TF_CALL_complex64(REGISTER_BATCH_MATMUL_CPU);
 TF_CALL_complex128(REGISTER_BATCH_MATMUL_CPU);
-#endif
+#endif  // !INTEL_MKL || INTEL_MKL_DNN_ONLY || !ENABLE_MKL
 
 #if GOOGLE_CUDA
 TF_CALL_complex64(REGISTER_BATCH_MATMUL_GPU);
 TF_CALL_complex128(REGISTER_BATCH_MATMUL_GPU);
-#endif
+#endif  // GOOGLE_CUDA
 
 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/batch_matmul_op_real.cc b/tensorflow/core/kernels/batch_matmul_op_real.cc
index 584b507c700..25ae795d8e7 100644
--- a/tensorflow/core/kernels/batch_matmul_op_real.cc
+++ b/tensorflow/core/kernels/batch_matmul_op_real.cc
@@ -21,10 +21,15 @@ limitations under the License.
 
 namespace tensorflow {
 
-#if !defined(INTEL_MKL) || defined(INTEL_MKL_DNN_ONLY)
+// MKL_ML registers its own float and double kernels in mkl_batch_matmul_op.cc
+// if defined(INTEL_MKL) && !defined(INTEL_MKL_DNN_ONLY) && defined(ENABLE_MKL).
+// Anything else (the complement) should register the TF ones.
+// (MKL-DNN doesn't implement these kernels either.)
+#if !defined(INTEL_MKL) || defined(INTEL_MKL_DNN_ONLY) || !defined(ENABLE_MKL)
 TF_CALL_float(REGISTER_BATCH_MATMUL_CPU);
 TF_CALL_double(REGISTER_BATCH_MATMUL_CPU);
-#endif
+#endif  // !INTEL_MKL || INTEL_MKL_DNN_ONLY || !ENABLE_MKL
+
 TF_CALL_half(REGISTER_BATCH_MATMUL_CPU);
 TF_CALL_int32(REGISTER_BATCH_MATMUL_CPU);
 
diff --git a/tensorflow/core/kernels/cwise_ops_common.cc b/tensorflow/core/kernels/cwise_ops_common.cc
index 980edffceb3..8ad3b4d1fc9 100644
--- a/tensorflow/core/kernels/cwise_ops_common.cc
+++ b/tensorflow/core/kernels/cwise_ops_common.cc
@@ -20,9 +20,9 @@ namespace tensorflow {
 BinaryOpShared::BinaryOpShared(OpKernelConstruction* ctx, DataType out,
                                DataType in)
     : OpKernel(ctx) {
-#ifndef INTEL_MKL
+#if !defined(INTEL_MKL) || !defined(ENABLE_MKL)
   OP_REQUIRES_OK(ctx, ctx->MatchSignature({in, in}, {out}));
-#endif
+#endif  // !INTEL_MKL || !ENABLE_MKL
 }
 
 void BinaryOpShared::SetUnimplementedError(OpKernelContext* ctx) {
diff --git a/tensorflow/core/kernels/gather_nd_op_cpu_impl.h b/tensorflow/core/kernels/gather_nd_op_cpu_impl.h
index 277ee2be02d..1c78de253e7 100644
--- a/tensorflow/core/kernels/gather_nd_op_cpu_impl.h
+++ b/tensorflow/core/kernels/gather_nd_op_cpu_impl.h
@@ -114,7 +114,7 @@ struct GatherNdSlice<CPUDevice, T, Index, IXDIM> {
     generator::GatherNdSliceGenerator<T, Index, IXDIM> gather_nd_generator(
         slice_size, Tindices, Tparams, Tout, &error_loc);
 
-#ifdef INTEL_MKL
+#if defined(INTEL_MKL) && defined(ENABLE_MKL)
 // Eigen implementation below is not highly performant. gather_nd_generator
 // does not seem to be called in parallel, leading to very poor performance.
 // Additionally, since it uses scalar (Tscratch) to invoke 'generate', it
@@ -126,12 +126,12 @@ struct GatherNdSlice<CPUDevice, T, Index, IXDIM> {
       const Eigen::array<Eigen::DenseIndex, 1> loc{i};
       gather_nd_generator(loc);
     }
-#else  // INTEL_MKL
+#else   // INTEL_MKL && ENABLE_MKL
     Tscratch.device(d) = Tscratch.reshape(reshape_dims)
                              .broadcast(broadcast_dims)
                              .generate(gather_nd_generator)
                              .sum();
-#endif
+#endif  // INTEL_MKL && ENABLE_MKL
 
     // error_loc() returns -1 if there's no out-of-bounds index,
     // otherwise it returns the location of an OOB index in Tindices.
diff --git a/tensorflow/core/kernels/matmul_op.cc b/tensorflow/core/kernels/matmul_op.cc
index 79967aab381..4ad390a4116 100644
--- a/tensorflow/core/kernels/matmul_op.cc
+++ b/tensorflow/core/kernels/matmul_op.cc
@@ -578,7 +578,7 @@ struct MatMulFunctor<SYCLDevice, T> {
                               .Label("cublas"),                    \
                           MatMulOp<GPUDevice, T, true /* cublas */>)
 
-#if defined(INTEL_MKL)
+#if defined(INTEL_MKL) && defined(ENABLE_MKL)
 
 // MKL does not support half, bfloat16 and int32 types for
 // matrix-multiplication, so register the kernel to use default Eigen based
@@ -606,9 +606,9 @@ TF_CALL_double(REGISTER_CPU);
 TF_CALL_complex64(REGISTER_CPU_EIGEN);
 TF_CALL_complex128(REGISTER_CPU_EIGEN);
 TF_CALL_double(REGISTER_CPU_EIGEN);
-#endif
+#endif  // INTEL_MKL_DNN_ONLY
 
-#else  // INTEL MKL
+#else   // INTEL_MKL && ENABLE_MKL
 TF_CALL_float(REGISTER_CPU);
 TF_CALL_double(REGISTER_CPU);
 TF_CALL_half(REGISTER_CPU);
@@ -616,7 +616,7 @@ TF_CALL_bfloat16(REGISTER_CPU);
 TF_CALL_int32(REGISTER_CPU);
 TF_CALL_complex64(REGISTER_CPU);
 TF_CALL_complex128(REGISTER_CPU);
-#endif
+#endif  // INTEL_MKL && ENABLE_MKL
 
 #if GOOGLE_CUDA
 TF_CALL_float(REGISTER_GPU);
diff --git a/tensorflow/core/kernels/mkl_batch_matmul_op.cc b/tensorflow/core/kernels/mkl_batch_matmul_op.cc
index 0841395dc38..bc135de11e0 100644
--- a/tensorflow/core/kernels/mkl_batch_matmul_op.cc
+++ b/tensorflow/core/kernels/mkl_batch_matmul_op.cc
@@ -223,10 +223,12 @@ class BatchMatMulMkl : public OpKernel {
       Name("BatchMatMul").Device(DEVICE_CPU).TypeConstraint<TYPE>("T"), \
       BatchMatMulMkl<CPUDevice, TYPE>)
 
+#ifdef ENABLE_MKL
 TF_CALL_float(REGISTER_BATCH_MATMUL_MKL);
 TF_CALL_double(REGISTER_BATCH_MATMUL_MKL);
 TF_CALL_complex64(REGISTER_BATCH_MATMUL_MKL);
 TF_CALL_complex128(REGISTER_BATCH_MATMUL_MKL);
+#endif  // ENABLE_MKL
 
 }  // end namespace tensorflow
 #endif
diff --git a/tensorflow/core/kernels/mkl_matmul_op.cc b/tensorflow/core/kernels/mkl_matmul_op.cc
index 077d62ce325..f4788f48519 100644
--- a/tensorflow/core/kernels/mkl_matmul_op.cc
+++ b/tensorflow/core/kernels/mkl_matmul_op.cc
@@ -217,7 +217,7 @@ class MklMatMulOp : public OpKernel {
                 reinterpret_cast<const MKL_Complex16*>(b), ldb, &beta,
                 reinterpret_cast<MKL_Complex16*>(c), ldc);
   }
-#endif
+#endif  // !INTEL_MKL_DNN_ONLY
 };
 
 #define REGISTER_CPU(T)                                         \
@@ -225,6 +225,7 @@ class MklMatMulOp : public OpKernel {
       Name("MatMul").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
       MklMatMulOp<CPUDevice, T, false /* cublas, ignored for CPU */>);
 
+#ifdef ENABLE_MKL
 // TODO(inteltf) Consider template specialization when adding/removing
 // additional types
 TF_CALL_float(REGISTER_CPU);
@@ -233,7 +234,8 @@ TF_CALL_float(REGISTER_CPU);
 TF_CALL_double(REGISTER_CPU);
 TF_CALL_complex64(REGISTER_CPU);
 TF_CALL_complex128(REGISTER_CPU);
-#endif
+#endif  // !INTEL_MKL_DNN_ONLY
+#endif  // ENABLE_MKL
 
 }  // namespace tensorflow
 #endif  // INTEL_MKL
diff --git a/tensorflow/core/kernels/slice_op.cc b/tensorflow/core/kernels/slice_op.cc
index 77594479cb1..97f77e45b64 100644
--- a/tensorflow/core/kernels/slice_op.cc
+++ b/tensorflow/core/kernels/slice_op.cc
@@ -411,7 +411,7 @@ class MklSliceOp : public OpKernel {
         context->input(0).tensor<T, NDIM>(), indices, sizes);
   }
 };
-#endif
+#endif  // INTEL_MKL
 
 // Forward declarations of the functor specializations for declared in the
 // sharded source files.
@@ -440,19 +440,7 @@ TF_CALL_ALL_TYPES(DECLARE_FOR_N);
 #undef DECLARE_CPU_SPEC
 }  // namespace functor
 
-#ifndef INTEL_MKL
-#define REGISTER_SLICE(type)                             \
-  REGISTER_KERNEL_BUILDER(Name("Slice")                  \
-                              .Device(DEVICE_CPU)        \
-                              .TypeConstraint<type>("T") \
-                              .HostMemory("begin")       \
-                              .HostMemory("size"),       \
-                          SliceOp<CPUDevice, type>)
-
-TF_CALL_POD_STRING_TYPES(REGISTER_SLICE);
-TF_CALL_QUANTIZED_TYPES(REGISTER_SLICE);
-#undef REGISTER_SLICE
-#else
+#if defined(INTEL_MKL) && defined(ENABLE_MKL)
 #define REGISTER_SLICE(type)                             \
   REGISTER_KERNEL_BUILDER(Name("Slice")                  \
                               .Device(DEVICE_CPU)        \
@@ -460,11 +448,19 @@ TF_CALL_QUANTIZED_TYPES(REGISTER_SLICE);
                               .HostMemory("begin")       \
                               .HostMemory("size"),       \
                           MklSliceOp<CPUDevice, type>)
+#else
+#define REGISTER_SLICE(type)                             \
+  REGISTER_KERNEL_BUILDER(Name("Slice")                  \
+                              .Device(DEVICE_CPU)        \
+                              .TypeConstraint<type>("T") \
+                              .HostMemory("begin")       \
+                              .HostMemory("size"),       \
+                          SliceOp<CPUDevice, type>)
+#endif  // INTEL_MKL && ENABLE_MKL
 
 TF_CALL_POD_STRING_TYPES(REGISTER_SLICE);
 TF_CALL_QUANTIZED_TYPES(REGISTER_SLICE);
 #undef REGISTER_SLICE
-#endif  // INTEL_MKL
 
 #if GOOGLE_CUDA
 // Forward declarations of the functor specializations for GPU.
diff --git a/tensorflow/core/kernels/transpose_op.cc b/tensorflow/core/kernels/transpose_op.cc
index 0f0f65c5a37..48e392c0707 100644
--- a/tensorflow/core/kernels/transpose_op.cc
+++ b/tensorflow/core/kernels/transpose_op.cc
@@ -218,7 +218,7 @@ Status ConjugateTransposeCpuOp::DoTranspose(OpKernelContext* ctx,
                                             perm, out);
 }
 
-#if defined(INTEL_MKL)
+#if defined(INTEL_MKL) && defined(ENABLE_MKL)
 #define REGISTER(T)                                   \
   REGISTER_KERNEL_BUILDER(Name("Transpose")           \
                               .Device(DEVICE_CPU)     \
@@ -230,11 +230,8 @@ Status ConjugateTransposeCpuOp::DoTranspose(OpKernelContext* ctx,
                               .TypeConstraint<T>("T") \
                               .HostMemory("perm"),    \
                           MklConjugateTransposeCpuOp);
-TF_CALL_ALL_TYPES(REGISTER);
-#undef REGISTER
-
-#else  // INTEL_MKL
 
+#else  // INTEL_MKL && ENABLE_MKL
 #define REGISTER(T)                                   \
   REGISTER_KERNEL_BUILDER(Name("Transpose")           \
                               .Device(DEVICE_CPU)     \
@@ -246,9 +243,10 @@ TF_CALL_ALL_TYPES(REGISTER);
                               .TypeConstraint<T>("T") \
                               .HostMemory("perm"),    \
                           ConjugateTransposeCpuOp);
+#endif  // INTEL_MKL && ENABLE_MKL
+
 TF_CALL_ALL_TYPES(REGISTER)
 #undef REGISTER
-#endif  // INTEL_MKL
 
 #if GOOGLE_CUDA
 Status TransposeGpuOp::DoTranspose(OpKernelContext* ctx, const Tensor& in,
diff --git a/tensorflow/core/util/port.cc b/tensorflow/core/util/port.cc
index c081ceae57c..e01058dff6c 100644
--- a/tensorflow/core/util/port.cc
+++ b/tensorflow/core/util/port.cc
@@ -38,10 +38,10 @@ bool CudaSupportsHalfMatMulAndConv() {
 }
 
 bool IsMklEnabled() {
-#ifdef INTEL_MKL
+#if defined(INTEL_MKL) && defined(ENABLE_MKL)
   return true;
 #else
   return false;
-#endif
+#endif  // INTEL_MKL && ENABLE_MKL
 }
 }  // end namespace tensorflow
diff --git a/tensorflow/tensorflow.bzl b/tensorflow/tensorflow.bzl
index 7ddaf7806ea..d6c75d675c2 100644
--- a/tensorflow/tensorflow.bzl
+++ b/tensorflow/tensorflow.bzl
@@ -22,6 +22,7 @@ load(
 )
 load(
     "//third_party/mkl:build_defs.bzl",
+    "if_enable_mkl",
     "if_mkl",
     "if_mkl_lnx_x64",
     "if_mkl_ml",
@@ -237,6 +238,7 @@ def tf_copts(android_optimization_level_override = "-O2", is_external = False):
         if_tensorrt(["-DGOOGLE_TENSORRT=1"]) +
         if_mkl(["-DINTEL_MKL=1", "-DEIGEN_USE_VML"]) +
         if_mkl_open_source_only(["-DINTEL_MKL_DNN_ONLY"]) +
+        if_enable_mkl(["-DENABLE_MKL"]) +
         if_ngraph(["-DINTEL_NGRAPH=1"]) +
         if_mkl_lnx_x64(["-fopenmp"]) +
         if_android_arm(["-mfpu=neon"]) +
@@ -1082,6 +1084,7 @@ def tf_cuda_library(deps = None, cuda_deps = None, copts = tf_copts(), **kwargs)
         ]),
         copts = (copts + if_cuda(["-DGOOGLE_CUDA=1"]) + if_mkl(["-DINTEL_MKL=1"]) +
                  if_mkl_open_source_only(["-DINTEL_MKL_DNN_ONLY"]) +
+                 if_enable_mkl(["-DENABLE_MKL"]) +
                  if_tensorrt(["-DGOOGLE_TENSORRT=1"])),
         **kwargs
     )
diff --git a/third_party/mkl/BUILD b/third_party/mkl/BUILD
index efff7fd51b1..15a3e5cfa7e 100644
--- a/third_party/mkl/BUILD
+++ b/third_party/mkl/BUILD
@@ -1,26 +1,26 @@
 licenses(["notice"])  # 3-Clause BSD
 
 config_setting(
-    name = "using_mkl",
+    name = "build_with_mkl",
     define_values = {
-        "using_mkl": "true",
+        "build_with_mkl": "true",
     },
     visibility = ["//visibility:public"],
 )
 
 config_setting(
-    name = "using_mkl_ml_only",
+    name = "build_with_mkl_ml_only",
     define_values = {
-        "using_mkl": "true",
-        "using_mkl_ml_only": "true",
+        "build_with_mkl": "true",
+        "build_with_mkl_ml_only": "true",
     },
     visibility = ["//visibility:public"],
 )
 
 config_setting(
-    name = "using_mkl_lnx_x64",
+    name = "build_with_mkl_lnx_x64",
     define_values = {
-        "using_mkl": "true",
+        "build_with_mkl": "true",
     },
     values = {
         "cpu": "k8",
@@ -28,6 +28,15 @@ config_setting(
     visibility = ["//visibility:public"],
 )
 
+config_setting(
+    name = "enable_mkl",
+    define_values = {
+        "enable_mkl": "true",
+        "build_with_mkl": "true",
+    },
+    visibility = ["//visibility:public"],
+)
+
 load(
     "//third_party/mkl:build_defs.bzl",
     "if_mkl",
diff --git a/third_party/mkl/build_defs.bzl b/third_party/mkl/build_defs.bzl
index b645c0fc5c7..bb798e715ab 100644
--- a/third_party/mkl/build_defs.bzl
+++ b/third_party/mkl/build_defs.bzl
@@ -1,9 +1,11 @@
 # -*- Python -*-
 """Skylark macros for MKL.
-if_mkl is a conditional to check if MKL is enabled or not.
-if_mkl_ml is a conditional to check if MKL-ML is enabled.
+
+if_mkl is a conditional to check if we are building with MKL.
+if_mkl_ml is a conditional to check if we are building with MKL-ML.
 if_mkl_ml_only is a conditional to check for MKL-ML-only (no MKL-DNN) mode.
 if_mkl_lnx_x64 is a conditional to check for MKL
+if_enable_mkl is a conditional to check if building with MKL and MKL is enabled.
 
 mkl_repository is a repository rule for creating MKL repository rule that can
 be pointed to either a local folder, or download it from the internet.
@@ -24,7 +26,7 @@ def if_mkl(if_true, if_false = []):
       a select evaluating to either if_true or if_false as appropriate.
     """
     return select({
-        str(Label("//third_party/mkl:using_mkl")): if_true,
+        str(Label("//third_party/mkl:build_with_mkl")): if_true,
         "//conditions:default": if_false,
     })
 
@@ -40,8 +42,8 @@ def if_mkl_ml(if_true, if_false = []):
       a select evaluating to either if_true or if_false as appropriate.
     """
     return select({
-        str(Label("//third_party/mkl_dnn:using_mkl_dnn_only")): if_false,
-        str(Label("//third_party/mkl:using_mkl")): if_true,
+        str(Label("//third_party/mkl_dnn:build_with_mkl_dnn_only")): if_false,
+        str(Label("//third_party/mkl:build_with_mkl")): if_true,
         "//conditions:default": if_false,
     })
 
@@ -56,12 +58,12 @@ def if_mkl_ml_only(if_true, if_false = []):
       a select evaluating to either if_true or if_false as appropriate.
     """
     return select({
-        str(Label("//third_party/mkl:using_mkl_ml_only")): if_true,
+        str(Label("//third_party/mkl:build_with_mkl_ml_only")): if_true,
         "//conditions:default": if_false,
     })
 
 def if_mkl_lnx_x64(if_true, if_false = []):
-    """Shorthand to select() on if MKL is on and the target is Linux x86-64.
+    """Shorthand to select() if building with MKL and the target is Linux x86-64.
 
     Args:
       if_true: expression to evaluate if building with MKL is enabled and the
@@ -73,7 +75,24 @@ def if_mkl_lnx_x64(if_true, if_false = []):
       a select evaluating to either if_true or if_false as appropriate.
     """
     return select({
-        str(Label("//third_party/mkl:using_mkl_lnx_x64")): if_true,
+        str(Label("//third_party/mkl:build_with_mkl_lnx_x64")): if_true,
+        "//conditions:default": if_false,
+    })
+
+def if_enable_mkl(if_true, if_false = []):
+    """Shorthand to select() if we are building with MKL and MKL is enabled.
+
+    Needs both --define=build_with_mkl=true and --define=enable_mkl=true.
+
+    Args:
+      if_true: expression to evaluate if building with MKL and MKL is enabled
+      if_false: expression to evaluate if building without MKL or MKL is not enabled.
+
+    Returns:
+      A select evaluating to either if_true or if_false as appropriate.
+    """
+    return select({
+        str(Label("//third_party/mkl:enable_mkl")): if_true,
+        "//conditions:default": if_false,
+    })
 
@@ -87,9 +106,9 @@ def mkl_deps():
       inclusion in the deps attribute of rules.
     """
     return select({
-        str(Label("//third_party/mkl_dnn:using_mkl_dnn_only")): ["@mkl_dnn"],
-        str(Label("//third_party/mkl:using_mkl_ml_only")): ["//third_party/mkl:intel_binary_blob"],
-        str(Label("//third_party/mkl:using_mkl")): [
+        str(Label("//third_party/mkl_dnn:build_with_mkl_dnn_only")): ["@mkl_dnn"],
+        str(Label("//third_party/mkl:build_with_mkl_ml_only")): ["//third_party/mkl:intel_binary_blob"],
+        str(Label("//third_party/mkl:build_with_mkl")): [
             "//third_party/mkl:intel_binary_blob",
             "@mkl_dnn",
         ],
diff --git a/third_party/mkl_dnn/BUILD b/third_party/mkl_dnn/BUILD
index 3e567fa9fca..58ecda55e6e 100644
--- a/third_party/mkl_dnn/BUILD
+++ b/third_party/mkl_dnn/BUILD
@@ -3,10 +3,10 @@ licenses(["notice"])
 exports_files(["LICENSE"])
 
 config_setting(
-    name = "using_mkl_dnn_only",
+    name = "build_with_mkl_dnn_only",
     define_values = {
-        "using_mkl": "true",
-        "using_mkl_dnn_only": "true",
+        "build_with_mkl": "true",
+        "build_with_mkl_dnn_only": "true",
     },
     visibility = ["//visibility:public"],
 )
diff --git a/third_party/mkl_dnn/build_defs.bzl b/third_party/mkl_dnn/build_defs.bzl
index 7ce2a7d9b03..6388f31971c 100644
--- a/third_party/mkl_dnn/build_defs.bzl
+++ b/third_party/mkl_dnn/build_defs.bzl
@@ -8,6 +8,6 @@ def if_mkl_open_source_only(if_true, if_false = []):
 
     """
     return select({
-        str(Label("//third_party/mkl_dnn:using_mkl_dnn_only")): if_true,
+        str(Label("//third_party/mkl_dnn:build_with_mkl_dnn_only")): if_true,
         "//conditions:default": if_false,
     })
diff --git a/tools/bazel.rc b/tools/bazel.rc
index ccf62629d1e..6747c7e7951 100644
--- a/tools/bazel.rc
+++ b/tools/bazel.rc
@@ -24,12 +24,13 @@ build --define framework_shared_object=true
 # Please note that MKL on MacOS or windows is still not supported.
 # If you would like to use a local MKL instead of downloading, please set the
 # environment variable "TF_MKL_ROOT" every time before build.
-build:mkl --define=using_mkl=true
+build:mkl --define=build_with_mkl=true --define=enable_mkl=true
 build:mkl -c opt
 
 # This config option is used to enable MKL-DNN open source library only,
 # without depending on MKL binary version.
-build:mkl_open_source_only --define=using_mkl_dnn_only=true
+build:mkl_open_source_only --define=build_with_mkl_dnn_only=true
+build:mkl_open_source_only --define=build_with_mkl=true --define=enable_mkl=true
 
 build:download_clang --crosstool_top=@local_config_download_clang//:toolchain
 build:download_clang --define=using_clang=true