diff --git a/tensorflow/lite/kernels/add.cc b/tensorflow/lite/kernels/add.cc
index f224cb33eb0..17214d53ea1 100644
--- a/tensorflow/lite/kernels/add.cc
+++ b/tensorflow/lite/kernels/add.cc
@@ -177,8 +177,6 @@ void EvalAdd(TfLiteContext* context, TfLiteNode* node, TfLiteAddParams* params,
              const OpData* data, const TfLiteTensor* input1,
              const TfLiteTensor* input2, TfLiteTensor* output) {
   tflite::ArithmeticParams op_params;
-  // requires_flat_size_broadcast is used for BroadcastAdd4DSlow.
-  const bool requires_flat_size_broadcast = !HaveSameShapes(input1, input2);
   const bool need_broadcast = optimized_ops::ProcessBroadcastShapes(
       GetTensorShape(input1), GetTensorShape(input2), &op_params);
 #define TF_LITE_ADD(type, opname, data_type)                             \
@@ -193,13 +191,13 @@ void EvalAdd(TfLiteContext* context, TfLiteNode* node, TfLiteAddParams* params,
                GetTensorData<data_type>(output))
   if (output->type == kTfLiteInt32) {
     if (kernel_type == kReference) {
-      if (requires_flat_size_broadcast) {
+      if (need_broadcast) {
         TF_LITE_ADD(reference_ops, BroadcastAdd4DSlow, int32_t);
       } else {
         TF_LITE_ADD(reference_ops, Add, int32_t);
       }
     } else {
-      if (requires_flat_size_broadcast) {
+      if (need_broadcast) {
         TF_LITE_ADD(optimized_ops, BroadcastAdd4DSlow, int32_t);
       } else {
         TF_LITE_ADD(optimized_ops, Add, int32_t);
@@ -207,7 +205,7 @@ void EvalAdd(TfLiteContext* context, TfLiteNode* node, TfLiteAddParams* params,
     }
   } else if (output->type == kTfLiteFloat32) {
     if (kernel_type == kReference) {
-      if (requires_flat_size_broadcast) {
+      if (need_broadcast) {
         TF_LITE_ADD(reference_ops, BroadcastAdd4DSlow, float);
       } else {
         TF_LITE_ADD(reference_ops, Add, float);
@@ -215,8 +213,6 @@ void EvalAdd(TfLiteContext* context, TfLiteNode* node, TfLiteAddParams* params,
     } else {
       if (need_broadcast) {
         TF_LITE_ADD(optimized_ops, BroadcastAddFivefold, float);
-      } else if (requires_flat_size_broadcast) {
-        TF_LITE_ADD(optimized_ops, BroadcastAdd4DSlow, float);
       } else {
         TF_LITE_ADD(optimized_ops, Add, float);
       }
diff --git a/tensorflow/lite/kernels/add_test.cc b/tensorflow/lite/kernels/add_test.cc
index 9dd7df147c8..ef97b7785e1 100644
--- a/tensorflow/lite/kernels/add_test.cc
+++ b/tensorflow/lite/kernels/add_test.cc
@@ -190,6 +190,24 @@ TEST(IntegerAddOpModel, WithBroadcast) {
   }
 }
 
+TEST(IntegerAddOpModel, Int32MultiDimBroadcast) {
+  IntegerAddOpModel m({TensorType_INT32, {1, 2}}, {TensorType_INT32, {2, 1}},
+                      {TensorType_INT32, {}}, ActivationFunctionType_NONE);
+  m.PopulateTensor<int32_t>(m.input1(), {3, 5});
+  m.PopulateTensor<int32_t>(m.input2(), {1, 4});
+  m.Invoke();
+  EXPECT_THAT(m.GetOutput(), ElementsAreArray({4, 6, 7, 9}));
+}
+
+TEST(IntegerAddOpModel, Float32MultiDimBroadcast) {
+  FloatAddOpModel m({TensorType_FLOAT32, {1, 2}}, {TensorType_FLOAT32, {2, 1}},
+                    {TensorType_FLOAT32, {}}, ActivationFunctionType_NONE);
+  m.PopulateTensor<float>(m.input1(), {3, 5});
+  m.PopulateTensor<float>(m.input2(), {1, 4});
+  m.Invoke();
+  EXPECT_THAT(m.GetOutput(), ElementsAreArray({4, 6, 7, 9}));
+}
+
 template <TensorType tensor_type, typename integer_dtype>
 void QuantizedTestsNoActivation() {
   float kQuantizedTolerance = GetTolerance(-1.0, 1.0);
diff --git a/tensorflow/lite/kernels/internal/optimized/integer_ops/add.h b/tensorflow/lite/kernels/internal/optimized/integer_ops/add.h
index 1abf89a8e38..253944ca3f1 100644
--- a/tensorflow/lite/kernels/internal/optimized/integer_ops/add.h
+++ b/tensorflow/lite/kernels/internal/optimized/integer_ops/add.h
@@ -221,7 +221,7 @@ inline void Add(const ArithmeticParams& params,
                    params.quantized_activation_max);
   gemmlowp::ScopedProfilingLabel label("AddInt8/8bit");
   const int flat_size =
-      MatchingFlatSize(input1_shape, input2_shape, output_shape);
+      MatchingElementsSize(input1_shape, input2_shape, output_shape);
 
   TFLITE_DCHECK_GT(params.input1_offset, -256);
   TFLITE_DCHECK_GT(params.input2_offset, -256);
diff --git a/tensorflow/lite/kernels/internal/optimized/integer_ops/mul.h b/tensorflow/lite/kernels/internal/optimized/integer_ops/mul.h
index 08b8da09915..74b9d4b6a9e 100644
--- a/tensorflow/lite/kernels/internal/optimized/integer_ops/mul.h
+++ b/tensorflow/lite/kernels/internal/optimized/integer_ops/mul.h
@@ -175,7 +175,7 @@ inline void Mul(const ArithmeticParams& params,
                    params.quantized_activation_max);
   gemmlowp::ScopedProfilingLabel label("MulInt8/8bit");
   const int flat_size =
-      MatchingFlatSize(input1_shape, input2_shape, output_shape);
+      MatchingElementsSize(input1_shape, input2_shape, output_shape);
 
   MulElementwise(flat_size, params, input1_data, input2_data, output_data);
 }
diff --git a/tensorflow/lite/kernels/internal/optimized/optimized_ops.h b/tensorflow/lite/kernels/internal/optimized/optimized_ops.h
index 787cc4cd68a..ccb66ce819f 100644
--- a/tensorflow/lite/kernels/internal/optimized/optimized_ops.h
+++ b/tensorflow/lite/kernels/internal/optimized/optimized_ops.h
@@ -1543,7 +1543,7 @@ inline void Add(const ArithmeticParams& params,
                 const RuntimeShape& output_shape, float* output_data) {
   gemmlowp::ScopedProfilingLabel label("Add");
   const int flat_size =
-      MatchingFlatSize(input1_shape, input2_shape, output_shape);
+      MatchingElementsSize(input1_shape, input2_shape, output_shape);
   AddElementwise(flat_size, params, input1_data, input2_data, output_data);
 }
 
@@ -1782,7 +1782,7 @@ inline void Add(const ArithmeticParams& params,
                    params.quantized_activation_max);
   gemmlowp::ScopedProfilingLabel label("Add/8bit");
   const int flat_size =
-      MatchingFlatSize(input1_shape, input2_shape, output_shape);
+      MatchingElementsSize(input1_shape, input2_shape, output_shape);
 
   TFLITE_DCHECK_GT(params.input1_offset, -256);
   TFLITE_DCHECK_GT(params.input2_offset, -256);
@@ -1801,7 +1801,7 @@ inline void Add(const ArithmeticParams& params,
 
   const int input1_shift = params.input1_shift;
   const int flat_size =
-      MatchingFlatSize(output_shape, input1_shape, input2_shape);
+      MatchingElementsSize(input1_shape, input2_shape, output_shape);
   const int16 output_activation_min = params.quantized_activation_min;
   const int16 output_activation_max = params.quantized_activation_max;
 
@@ -1846,8 +1846,10 @@ inline void Add(const ArithmeticParams& params,
     auto scalar = input1_data[0];
     output_map.array() = scalar + input2_map.array();
   } else {
-    // Should not come here.
-    TFLITE_DCHECK(false);
+    reference_ops::BroadcastAdd4DSlow(params, input1_shape, input1_data,
+                                      input2_shape, input2_data, output_shape,
+                                      output_data);
+    return;
   }
   output_map = output_map.cwiseMax(params.quantized_activation_min);
   output_map = output_map.cwiseMin(params.quantized_activation_max);
@@ -2097,7 +2099,7 @@ inline void Mul(const ArithmeticParams& params,
   gemmlowp::ScopedProfilingLabel label("Mul");
 
   const int flat_size =
-      MatchingFlatSize(input1_shape, input2_shape, output_shape);
+      MatchingElementsSize(input1_shape, input2_shape, output_shape);
   MulElementwise(flat_size, params, input1_data, input2_data, output_data);
 }
 
@@ -2108,7 +2110,7 @@ inline void Mul(const ArithmeticParams& params,
   gemmlowp::ScopedProfilingLabel label("Mul/int32/activation");
 
   const int flat_size =
-      MatchingFlatSize(input1_shape, input2_shape, output_shape);
+      MatchingElementsSize(input1_shape, input2_shape, output_shape);
   const int32 output_activation_min = params.quantized_activation_min;
   const int32 output_activation_max = params.quantized_activation_max;
   for (int i = 0; i < flat_size; ++i) {
@@ -2139,8 +2141,9 @@ inline void MulNoActivation(const ArithmeticParams& params,
     auto scalar = input1_data[0];
     output_map.array() = scalar * input2_map.array();
   } else {
-    // Should not come here.
-    TFLITE_DCHECK(false);
+    reference_ops::BroadcastMul4DSlow(params, input1_shape, input1_data,
+                                      input2_shape, input2_data, output_shape,
+                                      output_data);
   }
 }
 
@@ -2153,7 +2156,7 @@ inline void Mul(const ArithmeticParams& params,
   // properly optimized version.
 
   const int flat_size =
-      MatchingFlatSize(input1_shape, input2_shape, output_shape);
+      MatchingElementsSize(input1_shape, input2_shape, output_shape);
 
   for (int i = 0; i < flat_size; i++) {
     // F0 uses 0 integer bits, range [-1, 1].
@@ -2178,7 +2181,7 @@ inline void Mul(const ArithmeticParams& params,
   TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
 
   const int flat_size =
-      MatchingFlatSize(input1_shape, input2_shape, output_shape);
+      MatchingElementsSize(input1_shape, input2_shape, output_shape);
 
   for (int i = 0; i < flat_size; i++) {
     // F0 uses 0 integer bits, range [-1, 1].
@@ -2342,6 +2345,8 @@ inline void MulSimpleBroadcast(int size, const ArithmeticParams& params,
 }
 
 // Broadcast mul that can often be used for inner loop of broadcast Mul.
+// This function will handle scalar_value (LHS) * vector_values (RHS).
+// Since it's a float function, input params does not matter here.
 inline void MulSimpleBroadcast(int size, const ArithmeticParams& params,
                                const float broadcast_value,
                                const float* input2_data, float* output_data) {
@@ -2380,7 +2385,7 @@ inline void Mul(const ArithmeticParams& params,
                    params.quantized_activation_max);
   gemmlowp::ScopedProfilingLabel label("Mul/8bit");
   const int flat_size =
-      MatchingFlatSize(input1_shape, input2_shape, output_shape);
+      MatchingElementsSize(input1_shape, input2_shape, output_shape);
 
   MulElementwise(flat_size, params, input1_data, input2_data, output_data);
 }
@@ -2509,6 +2514,8 @@ inline void BroadcastMulFivefold(const ArithmeticParams& params,
       for (int i1 = 0; i1 < y1; ++i1) {
         input2_data_ptr = input2_data_reset;
         for (int i2 = 0; i2 < y2; ++i2) {
+          // The input may be switched here, but the common parameters here
+          // do not matter as they will not influence the float math execution.
           MulSimpleBroadcast(y3, params, *input1_data_ptr, input2_data_ptr,
                              output_data_ptr);
           input2_data_ptr += y3;
@@ -2652,7 +2659,7 @@ inline void SubNonBroadcast(const ArithmeticParams& params,
                             float* output_data) {
   gemmlowp::ScopedProfilingLabel label("SubNonBroadcast");
   const int flat_size =
-      MatchingFlatSize(input1_shape, input2_shape, output_shape);
+      MatchingElementsSize(input1_shape, input2_shape, output_shape);
   for (int i = 0; i < flat_size; ++i) {
     output_data[i] = ActivationFunctionWithMinMax(
         input1_data[i] - input2_data[i], params.float_activation_min,
@@ -2669,7 +2676,7 @@ inline void SubWithActivation(const ArithmeticParams& params,
                               int32* output_data) {
   gemmlowp::ScopedProfilingLabel label("SubWithActivation/int32");
   const int flat_size =
-      MatchingFlatSize(input1_shape, input2_shape, input2_shape);
+      MatchingElementsSize(input1_shape, input2_shape, output_shape);
   for (int i = 0; i < flat_size; ++i) {
     output_data[i] = ActivationFunctionWithMinMax(
         input1_data[i] - input2_data[i], params.quantized_activation_min,
@@ -2686,7 +2693,7 @@ inline void SubWithActivation(const ArithmeticParams& params,
                               float* output_data) {
   gemmlowp::ScopedProfilingLabel label("SubWithActivation/float");
   const int flat_size =
-      MatchingFlatSize(input1_shape, input2_shape, input2_shape);
+      MatchingElementsSize(input1_shape, input2_shape, output_shape);
   for (int i = 0; i < flat_size; ++i) {
     output_data[i] = ActivationFunctionWithMinMax(
         input1_data[i] - input2_data[i], params.float_activation_min,
diff --git a/tensorflow/lite/kernels/internal/reference/add.h b/tensorflow/lite/kernels/internal/reference/add.h
index 5193a586fd0..d0c40912091 100644
--- a/tensorflow/lite/kernels/internal/reference/add.h
+++ b/tensorflow/lite/kernels/internal/reference/add.h
@@ -28,7 +28,7 @@ inline void Add(const ArithmeticParams& params,
                 const RuntimeShape& input2_shape, const T* input2_data,
                 const RuntimeShape& output_shape, T* output_data) {
   const int flat_size =
-      MatchingFlatSize(input1_shape, input2_shape, output_shape);
+      MatchingElementsSize(input1_shape, input2_shape, output_shape);
   for (int i = 0; i < flat_size; ++i) {
     output_data[i] = ActivationFunctionWithMinMax(
         input1_data[i] + input2_data[i], params.quantized_activation_min,
@@ -40,8 +40,9 @@ inline void Add(const ArithmeticParams& params,
                 const RuntimeShape& input1_shape, const float* input1_data,
                 const RuntimeShape& input2_shape, const float* input2_data,
                 const RuntimeShape& output_shape, float* output_data) {
-  const int size = MatchingFlatSize(input1_shape, input2_shape, output_shape);
-  for (int i = 0; i < size; i++) {
+  const int flat_size =
+      MatchingElementsSize(input1_shape, input2_shape, output_shape);
+  for (int i = 0; i < flat_size; i++) {
     auto x = input1_data[i] + input2_data[i];
     output_data[i] = ActivationFunctionWithMinMax(
         x, params.float_activation_min, params.float_activation_max);
@@ -122,7 +123,7 @@ inline void Add(const ArithmeticParams& params,
   TFLITE_DCHECK_LE(params.quantized_activation_min,
                    params.quantized_activation_max);
   const int flat_size =
-      MatchingFlatSize(input1_shape, input2_shape, output_shape);
+      MatchingElementsSize(input1_shape, input2_shape, output_shape);
 
   TFLITE_DCHECK_GT(params.input1_offset, -256);
   TFLITE_DCHECK_GT(params.input2_offset, -256);
@@ -140,7 +141,7 @@ inline void Add(const ArithmeticParams& params,
 
   const int input1_shift = params.input1_shift;
   const int flat_size =
-      MatchingFlatSize(output_shape, input1_shape, input2_shape);
+      MatchingElementsSize(input1_shape, input2_shape, output_shape);
   const int16 output_activation_min = params.quantized_activation_min;
   const int16 output_activation_max = params.quantized_activation_max;
 
diff --git a/tensorflow/lite/kernels/internal/reference/integer_ops/add.h b/tensorflow/lite/kernels/internal/reference/integer_ops/add.h
index e10092bafb5..69b42e08a6d 100644
--- a/tensorflow/lite/kernels/internal/reference/integer_ops/add.h
+++ b/tensorflow/lite/kernels/internal/reference/integer_ops/add.h
@@ -64,7 +64,7 @@ inline void Add(const ArithmeticParams& params,
   TFLITE_DCHECK_LE(params.quantized_activation_min,
                    params.quantized_activation_max);
   const int flat_size =
-      MatchingFlatSize(input1_shape, input2_shape, output_shape);
+      MatchingElementsSize(input1_shape, input2_shape, output_shape);
 
   const int32_t int8_max_value = std::numeric_limits<int8_t>::max();
   TFLITE_DCHECK_GE(params.input1_offset, -1 * int8_max_value);
diff --git a/tensorflow/lite/kernels/internal/reference/integer_ops/mul.h b/tensorflow/lite/kernels/internal/reference/integer_ops/mul.h
index 9c629ff2b8e..f054d07f9c6 100644
--- a/tensorflow/lite/kernels/internal/reference/integer_ops/mul.h
+++ b/tensorflow/lite/kernels/internal/reference/integer_ops/mul.h
@@ -48,7 +48,7 @@ inline void Mul(const ArithmeticParams& params,
                    params.quantized_activation_max);
   gemmlowp::ScopedProfilingLabel label("Mul/8bit");
   const int flat_size =
-      MatchingFlatSize(input1_shape, input2_shape, output_shape);
+      MatchingElementsSize(input1_shape, input2_shape, output_shape);
 
   MulElementwise(flat_size, params, input1_data, input2_data, output_data);
 }
@@ -65,7 +65,7 @@ inline void Mul(const ArithmeticParams& params,
   TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
 
   const int flat_size =
-      MatchingFlatSize(input1_shape, input2_shape, output_shape);
+      MatchingElementsSize(input1_shape, input2_shape, output_shape);
 
   for (int i = 0; i < flat_size; i++) {
     // F0 uses 0 integer bits, range [-1, 1].
diff --git a/tensorflow/lite/kernels/internal/reference/reference_ops.h b/tensorflow/lite/kernels/internal/reference/reference_ops.h
index 21fa7de92cf..304ba5d9a0c 100644
--- a/tensorflow/lite/kernels/internal/reference/reference_ops.h
+++ b/tensorflow/lite/kernels/internal/reference/reference_ops.h
@@ -350,7 +350,7 @@ inline void Mul(const ArithmeticParams& params,
   GetActivationParams(params, &output_activation_min, &output_activation_max);
 
   const int flat_size =
-      MatchingFlatSize(input1_shape, input2_shape, output_shape);
+      MatchingElementsSize(input1_shape, input2_shape, output_shape);
   for (int i = 0; i < flat_size; ++i) {
     output_data[i] = ActivationFunctionWithMinMax(
         input1_data[i] * input2_data[i], output_activation_min,
@@ -444,7 +444,7 @@ inline void Mul(const ArithmeticParams& params,
                    params.quantized_activation_max);
   gemmlowp::ScopedProfilingLabel label("Mul/8bit");
   const int flat_size =
-      MatchingFlatSize(input1_shape, input2_shape, output_shape);
+      MatchingElementsSize(input1_shape, input2_shape, output_shape);
 
   MulElementwise(flat_size, params, input1_data, input2_data, output_data);
 }
@@ -551,7 +551,7 @@ inline void Mul(const ArithmeticParams& params,
   gemmlowp::ScopedProfilingLabel label("Mul/Int16");
 
   const int flat_size =
-      MatchingFlatSize(input1_shape, input2_shape, output_shape);
+      MatchingElementsSize(input1_shape, input2_shape, output_shape);
 
   for (int i = 0; i < flat_size; i++) {
     // F0 uses 0 integer bits, range [-1, 1].
@@ -574,7 +574,7 @@ inline void Mul(const ArithmeticParams& params,
   TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
 
   const int flat_size =
-      MatchingFlatSize(input1_shape, input2_shape, output_shape);
+      MatchingElementsSize(input1_shape, input2_shape, output_shape);
 
   for (int i = 0; i < flat_size; i++) {
     // F0 uses 0 integer bits, range [-1, 1].
@@ -655,7 +655,7 @@ inline void Div(const ArithmeticParams& params,
   GetActivationParams(params, &output_activation_min, &output_activation_max);
 
   const int flat_size =
-      MatchingFlatSize(input1_shape, input2_shape, output_shape);
+      MatchingElementsSize(input1_shape, input2_shape, output_shape);
   for (int i = 0; i < flat_size; ++i) {
     output_data[i] = ActivationFunctionWithMinMax(
         input1_data[i] / input2_data[i], output_activation_min,
@@ -706,7 +706,7 @@ inline void Div(const ArithmeticParams& params,
                    params.quantized_activation_max);
   gemmlowp::ScopedProfilingLabel label("Div/8bit");
   const int flat_size =
-      MatchingFlatSize(input1_shape, input2_shape, output_shape);
+      MatchingElementsSize(input1_shape, input2_shape, output_shape);
 
   DivElementwise(flat_size, params, input1_data, input2_data, output_data);
 }
@@ -779,7 +779,7 @@ inline void SubNonBroadcast(const ArithmeticParams& params,
                             const RuntimeShape& output_shape,
                             float* output_data) {
   const int flat_size =
-      MatchingFlatSize(input1_shape, input2_shape, output_shape);
+      MatchingElementsSize(input1_shape, input2_shape, output_shape);
   for (int i = 0; i < flat_size; ++i) {
     output_data[i] = ActivationFunctionWithMinMax(
         input1_data[i] - input2_data[i], params.float_activation_min,
@@ -795,7 +795,7 @@ inline void SubNonBroadcast(const ArithmeticParams& params,
                             const RuntimeShape& output_shape,
                             int32* output_data) {
   const int flat_size =
-      MatchingFlatSize(input1_shape, input2_shape, output_shape);
+      MatchingElementsSize(input1_shape, input2_shape, output_shape);
   for (int i = 0; i < flat_size; ++i) {
     output_data[i] = ActivationFunctionWithMinMax(
         input1_data[i] - input2_data[i], params.quantized_activation_min,
@@ -1043,7 +1043,7 @@ inline void SubWithActivation(const ArithmeticParams& params,
                               int32* output_data) {
   gemmlowp::ScopedProfilingLabel label("SubWithActivation");
   const int flat_size =
-      MatchingFlatSize(input1_shape, input2_shape, input2_shape);
+      MatchingElementsSize(input1_shape, input2_shape, output_shape);
   for (int i = 0; i < flat_size; ++i) {
     output_data[i] = ActivationFunctionWithMinMax(
         input1_data[i] - input2_data[i], params.quantized_activation_min,
@@ -1059,7 +1059,7 @@ inline void SubWithActivation(const ArithmeticParams& params,
                               const RuntimeShape& output_shape,
                               float* output_data) {
   const int flat_size =
-      MatchingFlatSize(input1_shape, input2_shape, input2_shape);
+      MatchingElementsSize(input1_shape, input2_shape, output_shape);
   for (int i = 0; i < flat_size; ++i) {
     output_data[i] = ActivationFunctionWithMinMax(
         input1_data[i] - input2_data[i], params.float_activation_min,
@@ -1074,7 +1074,7 @@ inline void Sub16(const ArithmeticParams& params,
   gemmlowp::ScopedProfilingLabel label("Sub/Int16");
   const int input1_shift = params.input1_shift;
   const int flat_size =
-      MatchingFlatSize(output_shape, input1_shape, input2_shape);
+      MatchingElementsSize(input1_shape, input2_shape, output_shape);
   const int16 output_activation_min = params.quantized_activation_min;
   const int16 output_activation_max = params.quantized_activation_max;
 
diff --git a/tensorflow/lite/kernels/internal/types.h b/tensorflow/lite/kernels/internal/types.h
index eb7b630c574..1a4a4ee84c3 100644
--- a/tensorflow/lite/kernels/internal/types.h
+++ b/tensorflow/lite/kernels/internal/types.h
@@ -457,6 +457,25 @@ inline int RequiredBufferSizeForDims(const Dims<4>& dims) {
   return FlatSize(dims);
 }
 
+inline int MatchingElementsSize(const RuntimeShape& shape,
+                                const RuntimeShape& check_shape_0) {
+  const int size_1 = shape.FlatSize();
+  const int size_2 = check_shape_0.FlatSize();
+  TFLITE_CHECK_EQ(size_1, size_2);
+  return size_1;
+}
+
+inline int MatchingElementsSize(const RuntimeShape& shape,
+                                const RuntimeShape& check_shape_0,
+                                const RuntimeShape& check_shape_1) {
+  const int size_1 = shape.FlatSize();
+  const int size_2 = check_shape_0.FlatSize();
+  const int size_3 = check_shape_1.FlatSize();
+  TFLITE_CHECK_EQ(size_1, size_2);
+  TFLITE_CHECK_EQ(size_2, size_3);
+  return size_1;
+}
+
 // Flat size calculation, checking that dimensions match with one or more other
 // arrays.
 inline int MatchingFlatSize(const RuntimeShape& shape,
diff --git a/tensorflow/lite/kernels/mul.cc b/tensorflow/lite/kernels/mul.cc
index 9e2c3c81780..9feb1794076 100644
--- a/tensorflow/lite/kernels/mul.cc
+++ b/tensorflow/lite/kernels/mul.cc
@@ -110,8 +110,6 @@ void EvalMul(TfLiteContext* context, TfLiteNode* node, TfLiteMulParams* params,
              const OpData* data, const TfLiteTensor* input1,
              const TfLiteTensor* input2, TfLiteTensor* output) {
   tflite::ArithmeticParams op_params;
-  // requires_flat_size_broadcast is used for BroadcastMul4DSlow.
-  const bool requires_flat_size_broadcast = !HaveSameShapes(input1, input2);
   const bool need_broadcast = optimized_ops::ProcessBroadcastShapes(
       GetTensorShape(input1), GetTensorShape(input2), &op_params);
 #define TF_LITE_MUL(type, opname, data_type)                             \
@@ -127,13 +125,13 @@ void EvalMul(TfLiteContext* context, TfLiteNode* node, TfLiteMulParams* params,
 
   if (output->type == kTfLiteInt32) {
     if (kernel_type == kReference) {
-      if (requires_flat_size_broadcast) {
+      if (need_broadcast) {
         TF_LITE_MUL(reference_ops, BroadcastMul4DSlow, int32_t);
       } else {
         TF_LITE_MUL(reference_ops, Mul, int32_t);
       }
     } else {
-      if (requires_flat_size_broadcast) {
+      if (need_broadcast) {
         TF_LITE_MUL(optimized_ops, BroadcastMul4DSlow, int32_t);
       } else {
         TF_LITE_MUL(optimized_ops, Mul, int32_t);
@@ -141,7 +139,7 @@ void EvalMul(TfLiteContext* context, TfLiteNode* node, TfLiteMulParams* params,
     }
   } else if (output->type == kTfLiteFloat32) {
     if (kernel_type == kReference) {
-      if (requires_flat_size_broadcast) {
+      if (need_broadcast) {
         TF_LITE_MUL(reference_ops, BroadcastMul4DSlow, float);
       } else {
         TF_LITE_MUL(reference_ops, Mul, float);
@@ -149,8 +147,6 @@ void EvalMul(TfLiteContext* context, TfLiteNode* node, TfLiteMulParams* params,
     } else {
       if (need_broadcast) {
         TF_LITE_MUL(optimized_ops, BroadcastMulFivefold, float);
-      } else if (requires_flat_size_broadcast) {
-        TF_LITE_MUL(optimized_ops, BroadcastMul4DSlow, float);
       } else {
         TF_LITE_MUL(optimized_ops, Mul, float);
       }