diff --git a/tensorflow/lite/kernels/internal/BUILD b/tensorflow/lite/kernels/internal/BUILD
index 1e91d0a6b2c..0d68e735c9a 100644
--- a/tensorflow/lite/kernels/internal/BUILD
+++ b/tensorflow/lite/kernels/internal/BUILD
@@ -212,6 +212,7 @@ cc_library(
     ],
     copts = tflite_copts(),
     deps = [
+        ":cpu_check",
         ":quantization_util",
         ":strided_slice_logic",
         ":types",
@@ -254,6 +255,7 @@ cc_library(
     ],
     copts = tflite_copts(),
     deps = [
+        ":cpu_check",
         ":optimized_base",
         ":quantization_util",
         ":strided_slice_logic",
@@ -505,10 +507,10 @@ cc_library(
         ":types",
         "//tensorflow/lite/c:c_api_internal",
         "//tensorflow/lite/kernels:activation_functor",
+        "//tensorflow/lite/kernels:cpu_backend_context",
         "//tensorflow/lite/kernels:op_macros",
         "@arm_neon_2_x86_sse",
         "@gemmlowp//:fixedpoint",
-        "@gemmlowp//:profiler",
     ],
 )
 
@@ -560,12 +562,12 @@ cc_library(
     ],
     copts = NEON_FLAGS_IF_APPLICABLE,
     deps = [
-        "@com_google_absl//absl/base:core_headers",
+        ":cpu_check",
         "//tensorflow/lite/c:c_api_internal",
         "@arm_neon_2_x86_sse",
+        "//tensorflow/lite/kernels:cpu_backend_context",
         "//tensorflow/lite/kernels:op_macros",
         "@gemmlowp//:fixedpoint",
-        "@gemmlowp//:profiler",
     ] + select({
         ":aarch64": [
             ":neon_tensor_utils",
@@ -650,6 +652,7 @@ cc_test(
     name = "depthwiseconv_float_test",
     srcs = ["depthwiseconv_float_test.cc"],
     deps = [
+        ":cpu_check",
         ":optimized_base",
         ":reference_base",
         ":test_util",
@@ -811,6 +814,7 @@ cc_library(
         "optimized/cpu_check.h",
     ],
     deps = [
+        "//tensorflow/lite/kernels:cpu_backend_context",
     ] + select(
         {
             "//tensorflow:android": [
diff --git a/tensorflow/lite/kernels/internal/depthwiseconv_float_test.cc b/tensorflow/lite/kernels/internal/depthwiseconv_float_test.cc
index c7994c6838a..ba7a2d21221 100644
--- a/tensorflow/lite/kernels/internal/depthwiseconv_float_test.cc
+++ b/tensorflow/lite/kernels/internal/depthwiseconv_float_test.cc
@@ -22,6 +22,7 @@ limitations under the License.
 #include "tensorflow/lite/kernels/internal/types.h"
 
 #define ALLOW_SLOW_GENERIC_DEPTHWISECONV_FALLBACK
+#include "tensorflow/lite/kernels/internal/optimized/cpu_check.h"
 #include "tensorflow/lite/kernels/internal/optimized/depthwiseconv_float.h"
 #include "tensorflow/lite/kernels/internal/reference/depthwiseconv_float.h"
 
@@ -42,7 +43,8 @@ void TestOneDepthwiseConv(
                                reference_output_data.data());
   optimized_ops::DepthwiseConvImpl(
       params, input_shape, input_data, filter_shape, filter_data, bias_shape,
-      bias_data, output_shape, output_data.data(), nullptr, /*thread_start=*/0,
+      bias_data, output_shape, output_data.data(), CpuFlags(),
+      /*thread_start=*/0,
       /*thread_end=*/output_shape.Dims(1), /*thread_dim=*/1);
 
   double sum_abs_diff = 0;
diff --git a/tensorflow/lite/kernels/internal/optimized/cpu_check.h b/tensorflow/lite/kernels/internal/optimized/cpu_check.h
index ac4ea7d6dae..c5ad7925399 100644
--- a/tensorflow/lite/kernels/internal/optimized/cpu_check.h
+++ b/tensorflow/lite/kernels/internal/optimized/cpu_check.h
@@ -15,6 +15,8 @@ limitations under the License.
 #ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_CPU_CHECK_H_
 #define TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_CPU_CHECK_H_
 
+#include "tensorflow/lite/kernels/cpu_backend_context.h"
+
 namespace tflite {
 
 #ifdef __ANDROID__
@@ -44,6 +46,18 @@ inline bool TestCPUFeatureNeon() { return false; }
 
 #endif
 
+struct CpuFlags {
+  bool neon_dotprod = false;
+};
+
+inline void GetCpuFlags(CpuBackendContext* cpu_backend_context,
+                        CpuFlags* cpu_flags) {
+  ruy::Context* ruy_context = cpu_backend_context->ruy_context();
+  cpu_flags->neon_dotprod =
+      ruy_context != nullptr && (ruy_context->GetRuntimeEnabledPaths() &
+                                 ruy::Path::kNeonDotprod) != ruy::Path::kNone;
+}
+
 }  // namespace tflite
 
 // NEON_OR_PORTABLE(SomeFunc, args) calls NeonSomeFunc(args) if Neon is both
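The additions above let kernels snapshot ruy's runtime CPU detection once into a plain struct, so hot paths test a bool instead of re-querying the ruy context. A minimal sketch of the intended call pattern, assuming the usual kernel plumbing (`RunDepthwiseKernel` is an illustrative name; only `CpuFlags` and `GetCpuFlags` come from this patch):

```cc
#include "tensorflow/lite/kernels/cpu_backend_context.h"
#include "tensorflow/lite/kernels/internal/optimized/cpu_check.h"

void RunDepthwiseKernel(tflite::CpuBackendContext* cpu_backend_context) {
  tflite::CpuFlags cpu_flags;
  tflite::GetCpuFlags(cpu_backend_context, &cpu_flags);
  if (cpu_flags.neon_dotprod) {
    // Dispatch to the 3x3 dot-product (SDOT) kernels.
  } else {
    // Fall back to the generic NEON/portable path.
  }
}
```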
diff --git a/tensorflow/lite/kernels/internal/optimized/depthwiseconv_float.h b/tensorflow/lite/kernels/internal/optimized/depthwiseconv_float.h
index 74e8356bc61..f171ddd7825 100644
--- a/tensorflow/lite/kernels/internal/optimized/depthwiseconv_float.h
+++ b/tensorflow/lite/kernels/internal/optimized/depthwiseconv_float.h
@@ -16,8 +16,8 @@ limitations under the License.
 #define TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_DEPTHWISECONV_FLOAT_H_
 
 #include "profiling/instrumentation.h"
-#include "tensorflow/lite/kernels/cpu_backend_context.h"
 #include "tensorflow/lite/kernels/internal/common.h"
+#include "tensorflow/lite/kernels/internal/optimized/cpu_check.h"
 #include "tensorflow/lite/kernels/internal/types.h"
 
 namespace tflite {
@@ -897,7 +897,7 @@ inline void DepthwiseConvInitAccBuffer(int num_output_pixels, int output_depth,
 // For example, assume thread_start = 2, thread_end = 6, and thread_dim = 1, it
 // means that it will calculate DepthwiseConv for output_data[:, 2:5, :, :].
 //
-// The cpu_backend_context may be supplied as a nullptr by some callers. This
+// The cpu_flags argument is currently unused. This
 // parameter is included so that the signature matches that required by a
 // templated function. Other versions, such as quantized, need this parameter.
 inline void DepthwiseConvImpl(
@@ -905,8 +905,8 @@ inline void DepthwiseConvImpl(
     const float* input_data, const RuntimeShape& filter_shape,
     const float* filter_data, const RuntimeShape& bias_shape,
     const float* bias_data, const RuntimeShape& output_shape,
-    float* output_data, CpuBackendContext* cpu_backend_context,
-    int thread_start, int thread_end, int thread_dim) {
+    float* output_data, const CpuFlags& /* cpu_flags */, int thread_start,
+    int thread_end, int thread_dim) {
   gemmlowp::ScopedProfilingLabel label("DepthwiseConv/float/DepthwiseConvImpl");
 
   const int stride_width = params.stride_width;
@@ -1116,10 +1116,9 @@ inline void DepthwiseConv(
     const float* input_data, const RuntimeShape& filter_shape,
     const float* filter_data, const RuntimeShape& bias_shape,
     const float* bias_data, const RuntimeShape& output_shape,
-    float* output_data, CpuBackendContext* cpu_backend_context) {
+    float* output_data, const CpuFlags& cpu_flags) {
   DepthwiseConvImpl(params, input_shape, input_data, filter_shape, filter_data,
-                    bias_shape, bias_data, output_shape, output_data,
-                    cpu_backend_context,
+                    bias_shape, bias_data, output_shape, output_data, cpu_flags,
                     /*thread_start=*/0,
                     /*thread_end=*/output_shape.Dims(1), /*thread_dim=*/1);
 }
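The float DepthwiseConvImpl ignores cpu_flags by design: as the comment notes, the parameter only keeps the float and uint8 overloads signature-compatible for templated callers such as DepthwiseConvWorkerTask. A self-contained toy version of that pattern (all names hypothetical, not from the patch):

```cc
// Toy shared-signature pattern: one templated call site serves both element
// types, so even the overload that ignores Flags must still accept it.
struct Flags { bool fast = false; };

inline void Impl(float* out, const Flags& /*flags*/) { *out = 1.0f; }
inline void Impl(unsigned char* out, const Flags& flags) {
  *out = flags.fast ? 2 : 1;
}

template <typename T>
void Run(T* out, const Flags& flags) {
  Impl(out, flags);
}
```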
diff --git a/tensorflow/lite/kernels/internal/optimized/depthwiseconv_multithread.h b/tensorflow/lite/kernels/internal/optimized/depthwiseconv_multithread.h
index 05a1476b518..62c6f61ae47 100644
--- a/tensorflow/lite/kernels/internal/optimized/depthwiseconv_multithread.h
+++ b/tensorflow/lite/kernels/internal/optimized/depthwiseconv_multithread.h
@@ -17,6 +17,7 @@ limitations under the License.
 
 #include "tensorflow/lite/kernels/cpu_backend_context.h"
 #include "tensorflow/lite/kernels/cpu_backend_threadpool.h"
+#include "tensorflow/lite/kernels/internal/optimized/cpu_check.h"
 #include "tensorflow/lite/kernels/internal/optimized/depthwiseconv_float.h"
 #include "tensorflow/lite/kernels/internal/optimized/depthwiseconv_uint8.h"
 
@@ -36,8 +37,7 @@ struct DepthwiseConvWorkerTask : cpu_backend_threadpool::Task {
                           const RuntimeShape& filter_shape,
                           const T* filter_data, const RuntimeShape& bias_shape,
                           const TS* bias_data, const RuntimeShape& output_shape,
-                          T* output_data,
-                          CpuBackendContext* cpu_backend_context,
+                          T* output_data, const CpuFlags& cpu_flags,
                           int thread_start, int thread_end, int thread_dim)
       : params_(params),
         input_shape_(input_shape),
@@ -48,7 +48,7 @@ struct DepthwiseConvWorkerTask : cpu_backend_threadpool::Task {
         bias_data_(bias_data),
         output_shape_(output_shape),
         output_data_(output_data),
-        cpu_backend_context_(cpu_backend_context),
+        cpu_flags_(cpu_flags),
         thread_start_(thread_start),
         thread_end_(thread_end),
         thread_dim_(thread_dim) {}
@@ -56,8 +56,8 @@ struct DepthwiseConvWorkerTask : cpu_backend_threadpool::Task {
   void Run() override {
     DepthwiseConvImpl(params_, input_shape_, input_data_, filter_shape_,
                       filter_data_, bias_shape_, bias_data_, output_shape_,
-                      output_data_, cpu_backend_context_, thread_start_,
-                      thread_end_, thread_dim_);
+                      output_data_, cpu_flags_, thread_start_, thread_end_,
+                      thread_dim_);
   }
 
  private:
@@ -70,7 +70,7 @@ struct DepthwiseConvWorkerTask : cpu_backend_threadpool::Task {
   const TS* bias_data_;
   const RuntimeShape& output_shape_;
   T* output_data_;
-  CpuBackendContext* cpu_backend_context_;
+  const CpuFlags& cpu_flags_;
   int thread_start_;
   int thread_end_;
   int thread_dim_;
@@ -143,10 +143,13 @@ inline void DepthwiseConv(const DepthwiseParams& params,
   const int output_batches = output_shape.Dims(0);
   const int output_height = output_shape.Dims(1);
 
+  CpuFlags cpu_flags;
+  GetCpuFlags(cpu_backend_context, &cpu_flags);
+
   if (thread_count == 1) {
     DepthwiseConvImpl(params, input_shape, input_data, filter_shape,
                       filter_data, bias_shape, bias_data, output_shape,
-                      output_data, cpu_backend_context, /*thread_start=*/0,
+                      output_data, cpu_flags, /*thread_start=*/0,
                       /*thread_end=*/output_height, /*thread_dim=*/1);
     return;
   }
@@ -170,8 +173,8 @@ inline void DepthwiseConv(const DepthwiseParams& params,
         thread_start + (thread_dim_size - thread_start) / (thread_count - i);
     tasks.emplace_back(params, input_shape, input_data, filter_shape,
                        filter_data, bias_shape, bias_data, output_shape,
-                       output_data, cpu_backend_context, thread_start,
-                       thread_end, thread_dim);
+                       output_data, cpu_flags, thread_start, thread_end,
+                       thread_dim);
     thread_start = thread_end;
   }
   cpu_backend_threadpool::Execute(tasks.size(), tasks.data(),
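A note on the switch from CpuBackendContext* to const CpuFlags& inside DepthwiseConvWorkerTask: storing a reference is safe here because cpu_flags is a stack local of the dispatching DepthwiseConv and cpu_backend_threadpool::Execute does not return until every task has run. A condensed sketch of that lifetime argument (`FlagsTask` and `Dispatch` are illustrative; the blocking behavior of Execute is the assumption the real code relies on):

```cc
#include <vector>

#include "tensorflow/lite/kernels/cpu_backend_context.h"
#include "tensorflow/lite/kernels/cpu_backend_threadpool.h"
#include "tensorflow/lite/kernels/internal/optimized/cpu_check.h"

struct FlagsTask : tflite::cpu_backend_threadpool::Task {
  explicit FlagsTask(const tflite::CpuFlags& flags) : flags_(flags) {}
  void Run() override { (void)flags_.neon_dotprod; }
  const tflite::CpuFlags& flags_;  // reference, as in DepthwiseConvWorkerTask
};

void Dispatch(tflite::CpuBackendContext* ctx, int thread_count) {
  tflite::CpuFlags cpu_flags;  // stack local, queried once
  tflite::GetCpuFlags(ctx, &cpu_flags);
  std::vector<FlagsTask> tasks;
  tasks.reserve(thread_count);
  for (int i = 0; i < thread_count; ++i) tasks.emplace_back(cpu_flags);
  // Execute() joins all tasks before returning, so no task can outlive
  // cpu_flags.
  tflite::cpu_backend_threadpool::Execute(tasks.size(), tasks.data(), ctx);
}
```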
diff --git a/tensorflow/lite/kernels/internal/optimized/depthwiseconv_uint8.h b/tensorflow/lite/kernels/internal/optimized/depthwiseconv_uint8.h
index 06be0143454..9f8736e76d0 100644
--- a/tensorflow/lite/kernels/internal/optimized/depthwiseconv_uint8.h
+++ b/tensorflow/lite/kernels/internal/optimized/depthwiseconv_uint8.h
@@ -18,8 +18,8 @@ limitations under the License.
 #include <type_traits>
 
 #include "profiling/instrumentation.h"
-#include "tensorflow/lite/kernels/cpu_backend_context.h"
 #include "tensorflow/lite/kernels/internal/common.h"
+#include "tensorflow/lite/kernels/internal/optimized/cpu_check.h"
 #include "tensorflow/lite/kernels/internal/optimized/depthwiseconv_uint8_3x3_filter.h"
 #include "tensorflow/lite/kernels/internal/reference/depthwiseconv_uint8.h"
 #include "tensorflow/lite/kernels/internal/types.h"
@@ -1986,8 +1986,8 @@ inline void DepthwiseConvWithRounding(
     const uint8* input_data, const RuntimeShape& filter_shape,
     const uint8* filter_data, const RuntimeShape& bias_shape,
     const int32* bias_data, const RuntimeShape& output_shape,
-    uint8* output_data, CpuBackendContext* cpu_backend_context,
-    int thread_start, int thread_end, int thread_dim) {
+    uint8* output_data, const CpuFlags& cpu_flags, int thread_start,
+    int thread_end, int thread_dim) {
   gemmlowp::ScopedProfilingLabel label("DepthwiseConv/8bit");
   const int depth_multiplier = params.depth_multiplier;
   const int32 output_activation_min = params.quantized_activation_min;
@@ -2009,12 +2009,7 @@ inline void DepthwiseConvWithRounding(
 // Jetson TX-2. This compiler does not support the offsetof() macro.
 #if defined(__aarch64__) && !defined(GOOGLE_L4T)
   // Dispatch to dot-product 3x3 kernels when supported.
-
-  ruy::Context* ruy_context = cpu_backend_context->ruy_context();
-  const bool has_dot_product_instructions =
-      ruy_context != nullptr && (ruy_context->GetRuntimeEnabledPaths() &
-                                 ruy::Path::kNeonDotprod) != ruy::Path::kNone;
-  if (has_dot_product_instructions) {
+  if (cpu_flags.neon_dotprod) {
     using optimized_ops::depthwise_conv::DotProduct3x3KernelType;
     DotProduct3x3KernelType kernel_type =
         optimized_ops::depthwise_conv::CategorizeDotProductKernel(
@@ -2067,12 +2062,12 @@ inline void DepthwiseConvImpl(
     const uint8* input_data, const RuntimeShape& filter_shape,
     const uint8* filter_data, const RuntimeShape& bias_shape,
     const int32* bias_data, const RuntimeShape& output_shape,
-    uint8* output_data, CpuBackendContext* cpu_backend_context,
-    int thread_start, int thread_end, int thread_dim) {
+    uint8* output_data, const CpuFlags& cpu_flags, int thread_start,
+    int thread_end, int thread_dim) {
   return DepthwiseConvWithRounding<DepthwiseConvOutputRounding::kUpward>(
       params, input_shape, input_data, filter_shape, filter_data, bias_shape,
-      bias_data, output_shape, output_data, cpu_backend_context, thread_start,
-      thread_end, thread_dim);
+      bias_data, output_shape, output_data, cpu_flags, thread_start, thread_end,
+      thread_dim);
 }
 
 void DepthwiseConv(const DepthwiseParams& params,
@@ -2080,7 +2075,7 @@ void DepthwiseConv(const DepthwiseParams& params,
                    const RuntimeShape& filter_shape, const uint8* filter_data,
                    const RuntimeShape& bias_shape, const int32* bias_data,
                    const RuntimeShape& output_shape, uint8* output_data,
-                   CpuBackendContext* cpu_backend_context);
+                   const CpuFlags& cpu_flags);
 
 }  // namespace optimized_ops
 }  // namespace tflite
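The uint8 path is where the flags are actually consumed: the dot-product dispatch check that previously ran inside DepthwiseConvWithRounding, i.e. once per worker thread, is now evaluated once per op call by GetCpuFlags. Both expressions appear verbatim in this patch, so their equivalence is easy to state (a sketch; assumes a non-null cpu_backend_context, as the old code did):

```cc
#include "tensorflow/lite/kernels/cpu_backend_context.h"
#include "tensorflow/lite/kernels/internal/optimized/cpu_check.h"

// Returns true iff the removed inline check and the hoisted CpuFlags check
// agree; both bodies are lifted directly from this patch.
bool ChecksAgree(tflite::CpuBackendContext* cpu_backend_context) {
  ruy::Context* ruy_context = cpu_backend_context->ruy_context();
  const bool has_dot_product_instructions =
      ruy_context != nullptr && (ruy_context->GetRuntimeEnabledPaths() &
                                 ruy::Path::kNeonDotprod) != ruy::Path::kNone;

  tflite::CpuFlags cpu_flags;
  tflite::GetCpuFlags(cpu_backend_context, &cpu_flags);
  return cpu_flags.neon_dotprod == has_dot_product_instructions;
}
```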
diff --git a/tensorflow/lite/kernels/internal/optimized/legacy_optimized_ops.h b/tensorflow/lite/kernels/internal/optimized/legacy_optimized_ops.h
index f1482f71c4c..08fee8caf14 100644
--- a/tensorflow/lite/kernels/internal/optimized/legacy_optimized_ops.h
+++ b/tensorflow/lite/kernels/internal/optimized/legacy_optimized_ops.h
@@ -20,6 +20,7 @@ limitations under the License.
 
 #include "public/gemmlowp.h"
 #include "tensorflow/lite/kernels/internal/common.h"
+#include "tensorflow/lite/kernels/internal/optimized/cpu_check.h"
 #include "tensorflow/lite/kernels/internal/optimized/depthwiseconv_multithread.h"
 #include "tensorflow/lite/kernels/internal/optimized/integer_ops/depthwise_conv.h"
 #include "tensorflow/lite/kernels/internal/optimized/integer_ops/fully_connected.h"
@@ -165,7 +166,7 @@ inline void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
   DepthwiseConvImpl(op_params, DimsToShape(input_dims), input_data,
                     DimsToShape(filter_dims), filter_data,
                     DimsToShape(bias_dims), bias_data, output_shape,
-                    output_data, nullptr, /*thread_start=*/0,
+                    output_data, CpuFlags(), /*thread_start=*/0,
                     /*thread_end=*/output_height, /*thread_dim=*/1);
 }
 
@@ -598,7 +599,8 @@ inline void DepthwiseConv(
     const float* bias_data, const RuntimeShape& output_shape,
     float* output_data) {
   DepthwiseConvImpl(params, input_shape, input_data, filter_shape, filter_data,
-                    bias_shape, bias_data, output_shape, output_data, nullptr,
+                    bias_shape, bias_data, output_shape, output_data,
+                    CpuFlags(),
                     /*thread_start=*/0,
                     /*thread_end=*/output_shape.Dims(1), /*thread_dim=*/1);
 }
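Finally, call sites with no CpuBackendContext to query, namely the legacy entry points above and the float test, pass a value-initialized CpuFlags(). Since neon_dotprod defaults to false, this is the drop-in replacement for the old nullptr argument: it requests no optional CPU features. A trivial check of that default (illustrative only):

```cc
#include <cassert>

#include "tensorflow/lite/kernels/internal/optimized/cpu_check.h"

int main() {
  tflite::CpuFlags flags;       // as passed by the legacy call sites
  assert(!flags.neon_dotprod);  // default: no dot-product dispatch
  return 0;
}
```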