diff --git a/tensorflow/lite/kernels/add.cc b/tensorflow/lite/kernels/add.cc index f224cb33eb0..17214d53ea1 100644 --- a/tensorflow/lite/kernels/add.cc +++ b/tensorflow/lite/kernels/add.cc @@ -177,8 +177,6 @@ void EvalAdd(TfLiteContext* context, TfLiteNode* node, TfLiteAddParams* params, const OpData* data, const TfLiteTensor* input1, const TfLiteTensor* input2, TfLiteTensor* output) { tflite::ArithmeticParams op_params; - // requires_flat_size_broadcast is used for BroadcastAdd4DSlow. - const bool requires_flat_size_broadcast = !HaveSameShapes(input1, input2); const bool need_broadcast = optimized_ops::ProcessBroadcastShapes( GetTensorShape(input1), GetTensorShape(input2), &op_params); #define TF_LITE_ADD(type, opname, data_type) \ @@ -193,13 +191,13 @@ void EvalAdd(TfLiteContext* context, TfLiteNode* node, TfLiteAddParams* params, GetTensorData<data_type>(output)) if (output->type == kTfLiteInt32) { if (kernel_type == kReference) { - if (requires_flat_size_broadcast) { + if (need_broadcast) { TF_LITE_ADD(reference_ops, BroadcastAdd4DSlow, int32_t); } else { TF_LITE_ADD(reference_ops, Add, int32_t); } } else { - if (requires_flat_size_broadcast) { + if (need_broadcast) { TF_LITE_ADD(optimized_ops, BroadcastAdd4DSlow, int32_t); } else { TF_LITE_ADD(optimized_ops, Add, int32_t); @@ -207,7 +205,7 @@ void EvalAdd(TfLiteContext* context, TfLiteNode* node, TfLiteAddParams* params, } } else if (output->type == kTfLiteFloat32) { if (kernel_type == kReference) { - if (requires_flat_size_broadcast) { + if (need_broadcast) { TF_LITE_ADD(reference_ops, BroadcastAdd4DSlow, float); } else { TF_LITE_ADD(reference_ops, Add, float); @@ -215,8 +213,6 @@ void EvalAdd(TfLiteContext* context, TfLiteNode* node, TfLiteAddParams* params, } else { if (need_broadcast) { TF_LITE_ADD(optimized_ops, BroadcastAddFivefold, float); - } else if (requires_flat_size_broadcast) { - TF_LITE_ADD(optimized_ops, BroadcastAdd4DSlow, float); } else { TF_LITE_ADD(optimized_ops, Add, float); } diff --git a/tensorflow/lite/kernels/add_test.cc b/tensorflow/lite/kernels/add_test.cc index 9dd7df147c8..ef97b7785e1 100644 --- a/tensorflow/lite/kernels/add_test.cc +++ b/tensorflow/lite/kernels/add_test.cc @@ -190,6 +190,24 @@ TEST(IntegerAddOpModel, WithBroadcast) { } } +TEST(IntegerAddOpModel, Int32MultiDimBroadcast) { + IntegerAddOpModel m({TensorType_INT32, {1, 2}}, {TensorType_INT32, {2, 1}}, + {TensorType_INT32, {}}, ActivationFunctionType_NONE); + m.PopulateTensor<int32_t>(m.input1(), {3, 5}); + m.PopulateTensor<int32_t>(m.input2(), {1, 4}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({4, 6, 7, 9})); +} + +TEST(IntegerAddOpModel, Float32MultiDimBroadcast) { + FloatAddOpModel m({TensorType_FLOAT32, {1, 2}}, {TensorType_FLOAT32, {2, 1}}, + {TensorType_FLOAT32, {}}, ActivationFunctionType_NONE); + m.PopulateTensor<float>(m.input1(), {3, 5}); + m.PopulateTensor<float>(m.input2(), {1, 4}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({4, 6, 7, 9})); +} + template <TensorType tensor_type, typename integer_dtype> void QuantizedTestsNoActivation() { float kQuantizedTolerance = GetTolerance(-1.0, 1.0); diff --git a/tensorflow/lite/kernels/internal/optimized/integer_ops/add.h b/tensorflow/lite/kernels/internal/optimized/integer_ops/add.h index 1abf89a8e38..253944ca3f1 100644 --- a/tensorflow/lite/kernels/internal/optimized/integer_ops/add.h +++ b/tensorflow/lite/kernels/internal/optimized/integer_ops/add.h @@ -221,7 +221,7 @@ inline void Add(const ArithmeticParams& params, params.quantized_activation_max); gemmlowp::ScopedProfilingLabel label("AddInt8/8bit"); const int flat_size = - MatchingFlatSize(input1_shape, input2_shape, output_shape); + MatchingElementsSize(input1_shape, input2_shape, output_shape); TFLITE_DCHECK_GT(params.input1_offset, -256); TFLITE_DCHECK_GT(params.input2_offset, -256); diff --git a/tensorflow/lite/kernels/internal/optimized/integer_ops/mul.h b/tensorflow/lite/kernels/internal/optimized/integer_ops/mul.h index 08b8da09915..74b9d4b6a9e 100644 --- a/tensorflow/lite/kernels/internal/optimized/integer_ops/mul.h +++ b/tensorflow/lite/kernels/internal/optimized/integer_ops/mul.h @@ -175,7 +175,7 @@ inline void Mul(const ArithmeticParams& params, params.quantized_activation_max); gemmlowp::ScopedProfilingLabel label("MulInt8/8bit"); const int flat_size = - MatchingFlatSize(input1_shape, input2_shape, output_shape); + MatchingElementsSize(input1_shape, input2_shape, output_shape); MulElementwise(flat_size, params, input1_data, input2_data, output_data); } diff --git a/tensorflow/lite/kernels/internal/optimized/optimized_ops.h b/tensorflow/lite/kernels/internal/optimized/optimized_ops.h index 787cc4cd68a..ccb66ce819f 100644 --- a/tensorflow/lite/kernels/internal/optimized/optimized_ops.h +++ b/tensorflow/lite/kernels/internal/optimized/optimized_ops.h @@ -1543,7 +1543,7 @@ inline void Add(const ArithmeticParams& params, const RuntimeShape& output_shape, float* output_data) { gemmlowp::ScopedProfilingLabel label("Add"); const int flat_size = - MatchingFlatSize(input1_shape, input2_shape, output_shape); + MatchingElementsSize(input1_shape, input2_shape, output_shape); AddElementwise(flat_size, params, input1_data, input2_data, output_data); } @@ -1782,7 +1782,7 @@ inline void Add(const ArithmeticParams& params, params.quantized_activation_max); gemmlowp::ScopedProfilingLabel label("Add/8bit"); const int flat_size = - MatchingFlatSize(input1_shape, input2_shape, output_shape); + MatchingElementsSize(input1_shape, input2_shape, output_shape); TFLITE_DCHECK_GT(params.input1_offset, -256); TFLITE_DCHECK_GT(params.input2_offset, -256); @@ -1801,7 +1801,7 @@ inline void Add(const ArithmeticParams& params, const int input1_shift = params.input1_shift; const int flat_size = - MatchingFlatSize(output_shape, input1_shape, input2_shape); + MatchingElementsSize(input1_shape, input2_shape, output_shape); const int16 output_activation_min = params.quantized_activation_min; const int16 output_activation_max = params.quantized_activation_max; @@ -1846,8 +1846,10 @@ inline void Add(const ArithmeticParams& params, auto scalar = input1_data[0]; output_map.array() = scalar + input2_map.array(); } else { - // Should not come here. - TFLITE_DCHECK(false); + reference_ops::BroadcastAdd4DSlow(params, input1_shape, input1_data, + input2_shape, input2_data, output_shape, + output_data); + return; } output_map = output_map.cwiseMax(params.quantized_activation_min); output_map = output_map.cwiseMin(params.quantized_activation_max); @@ -2097,7 +2099,7 @@ inline void Mul(const ArithmeticParams& params, gemmlowp::ScopedProfilingLabel label("Mul"); const int flat_size = - MatchingFlatSize(input1_shape, input2_shape, output_shape); + MatchingElementsSize(input1_shape, input2_shape, output_shape); MulElementwise(flat_size, params, input1_data, input2_data, output_data); } @@ -2108,7 +2110,7 @@ inline void Mul(const ArithmeticParams& params, gemmlowp::ScopedProfilingLabel label("Mul/int32/activation"); const int flat_size = - MatchingFlatSize(input1_shape, input2_shape, output_shape); + MatchingElementsSize(input1_shape, input2_shape, output_shape); const int32 output_activation_min = params.quantized_activation_min; const int32 output_activation_max = params.quantized_activation_max; for (int i = 0; i < flat_size; ++i) { @@ -2139,8 +2141,9 @@ inline void MulNoActivation(const ArithmeticParams& params, auto scalar = input1_data[0]; output_map.array() = scalar * input2_map.array(); } else { - // Should not come here. - TFLITE_DCHECK(false); + reference_ops::BroadcastMul4DSlow(params, input1_shape, input1_data, + input2_shape, input2_data, output_shape, + output_data); } } @@ -2153,7 +2156,7 @@ inline void Mul(const ArithmeticParams& params, // properly optimized version. const int flat_size = - MatchingFlatSize(input1_shape, input2_shape, output_shape); + MatchingElementsSize(input1_shape, input2_shape, output_shape); for (int i = 0; i < flat_size; i++) { // F0 uses 0 integer bits, range [-1, 1]. @@ -2178,7 +2181,7 @@ inline void Mul(const ArithmeticParams& params, TFLITE_DCHECK_LE(output_activation_min, output_activation_max); const int flat_size = - MatchingFlatSize(input1_shape, input2_shape, output_shape); + MatchingElementsSize(input1_shape, input2_shape, output_shape); for (int i = 0; i < flat_size; i++) { // F0 uses 0 integer bits, range [-1, 1]. @@ -2342,6 +2345,8 @@ inline void MulSimpleBroadcast(int size, const ArithmeticParams& params, } // Broadcast mul that can often be used for inner loop of broadcast Mul. +// This function will handle scalar_value (LHS) * vector_values (RHS). +// Since it's a float function, input params does not matter here. inline void MulSimpleBroadcast(int size, const ArithmeticParams& params, const float broadcast_value, const float* input2_data, float* output_data) { @@ -2380,7 +2385,7 @@ inline void Mul(const ArithmeticParams& params, params.quantized_activation_max); gemmlowp::ScopedProfilingLabel label("Mul/8bit"); const int flat_size = - MatchingFlatSize(input1_shape, input2_shape, output_shape); + MatchingElementsSize(input1_shape, input2_shape, output_shape); MulElementwise(flat_size, params, input1_data, input2_data, output_data); } @@ -2509,6 +2514,8 @@ inline void BroadcastMulFivefold(const ArithmeticParams& params, for (int i1 = 0; i1 < y1; ++i1) { input2_data_ptr = input2_data_reset; for (int i2 = 0; i2 < y2; ++i2) { + // The input may be switched here, but the common parameters here + // do not matter as they will not influence the float math execution. MulSimpleBroadcast(y3, params, *input1_data_ptr, input2_data_ptr, output_data_ptr); input2_data_ptr += y3; @@ -2652,7 +2659,7 @@ inline void SubNonBroadcast(const ArithmeticParams& params, float* output_data) { gemmlowp::ScopedProfilingLabel label("SubNonBroadcast"); const int flat_size = - MatchingFlatSize(input1_shape, input2_shape, output_shape); + MatchingElementsSize(input1_shape, input2_shape, output_shape); for (int i = 0; i < flat_size; ++i) { output_data[i] = ActivationFunctionWithMinMax( input1_data[i] - input2_data[i], params.float_activation_min, @@ -2669,7 +2676,7 @@ inline void SubWithActivation(const ArithmeticParams& params, int32* output_data) { gemmlowp::ScopedProfilingLabel label("SubWithActivation/int32"); const int flat_size = - MatchingFlatSize(input1_shape, input2_shape, input2_shape); + MatchingElementsSize(input1_shape, input2_shape, output_shape); for (int i = 0; i < flat_size; ++i) { output_data[i] = ActivationFunctionWithMinMax( input1_data[i] - input2_data[i], params.quantized_activation_min, @@ -2686,7 +2693,7 @@ inline void SubWithActivation(const ArithmeticParams& params, float* output_data) { gemmlowp::ScopedProfilingLabel label("SubWithActivation/float"); const int flat_size = - MatchingFlatSize(input1_shape, input2_shape, input2_shape); + MatchingElementsSize(input1_shape, input2_shape, output_shape); for (int i = 0; i < flat_size; ++i) { output_data[i] = ActivationFunctionWithMinMax( input1_data[i] - input2_data[i], params.float_activation_min, diff --git a/tensorflow/lite/kernels/internal/reference/add.h b/tensorflow/lite/kernels/internal/reference/add.h index 5193a586fd0..d0c40912091 100644 --- a/tensorflow/lite/kernels/internal/reference/add.h +++ b/tensorflow/lite/kernels/internal/reference/add.h @@ -28,7 +28,7 @@ inline void Add(const ArithmeticParams& params, const RuntimeShape& input2_shape, const T* input2_data, const RuntimeShape& output_shape, T* output_data) { const int flat_size = - MatchingFlatSize(input1_shape, input2_shape, output_shape); + MatchingElementsSize(input1_shape, input2_shape, output_shape); for (int i = 0; i < flat_size; ++i) { output_data[i] = ActivationFunctionWithMinMax( input1_data[i] + input2_data[i], params.quantized_activation_min, @@ -40,8 +40,9 @@ inline void Add(const ArithmeticParams& params, const RuntimeShape& input1_shape, const float* input1_data, const RuntimeShape& input2_shape, const float* input2_data, const RuntimeShape& output_shape, float* output_data) { - const int size = MatchingFlatSize(input1_shape, input2_shape, output_shape); - for (int i = 0; i < size; i++) { + const int flat_size = + MatchingElementsSize(input1_shape, input2_shape, output_shape); + for (int i = 0; i < flat_size; i++) { auto x = input1_data[i] + input2_data[i]; output_data[i] = ActivationFunctionWithMinMax( x, params.float_activation_min, params.float_activation_max); @@ -122,7 +123,7 @@ inline void Add(const ArithmeticParams& params, TFLITE_DCHECK_LE(params.quantized_activation_min, params.quantized_activation_max); const int flat_size = - MatchingFlatSize(input1_shape, input2_shape, output_shape); + MatchingElementsSize(input1_shape, input2_shape, output_shape); TFLITE_DCHECK_GT(params.input1_offset, -256); TFLITE_DCHECK_GT(params.input2_offset, -256); @@ -140,7 +141,7 @@ inline void Add(const ArithmeticParams& params, const int input1_shift = params.input1_shift; const int flat_size = - MatchingFlatSize(output_shape, input1_shape, input2_shape); + MatchingElementsSize(input1_shape, input2_shape, output_shape); const int16 output_activation_min = params.quantized_activation_min; const int16 output_activation_max = params.quantized_activation_max; diff --git a/tensorflow/lite/kernels/internal/reference/integer_ops/add.h b/tensorflow/lite/kernels/internal/reference/integer_ops/add.h index e10092bafb5..69b42e08a6d 100644 --- a/tensorflow/lite/kernels/internal/reference/integer_ops/add.h +++ b/tensorflow/lite/kernels/internal/reference/integer_ops/add.h @@ -64,7 +64,7 @@ inline void Add(const ArithmeticParams& params, TFLITE_DCHECK_LE(params.quantized_activation_min, params.quantized_activation_max); const int flat_size = - MatchingFlatSize(input1_shape, input2_shape, output_shape); + MatchingElementsSize(input1_shape, input2_shape, output_shape); const int32_t int8_max_value = std::numeric_limits<int8_t>::max(); TFLITE_DCHECK_GE(params.input1_offset, -1 * int8_max_value); diff --git a/tensorflow/lite/kernels/internal/reference/integer_ops/mul.h b/tensorflow/lite/kernels/internal/reference/integer_ops/mul.h index 9c629ff2b8e..f054d07f9c6 100644 --- a/tensorflow/lite/kernels/internal/reference/integer_ops/mul.h +++ b/tensorflow/lite/kernels/internal/reference/integer_ops/mul.h @@ -48,7 +48,7 @@ inline void Mul(const ArithmeticParams& params, params.quantized_activation_max); gemmlowp::ScopedProfilingLabel label("Mul/8bit"); const int flat_size = - MatchingFlatSize(input1_shape, input2_shape, output_shape); + MatchingElementsSize(input1_shape, input2_shape, output_shape); MulElementwise(flat_size, params, input1_data, input2_data, output_data); } @@ -65,7 +65,7 @@ inline void Mul(const ArithmeticParams& params, TFLITE_DCHECK_LE(output_activation_min, output_activation_max); const int flat_size = - MatchingFlatSize(input1_shape, input2_shape, output_shape); + MatchingElementsSize(input1_shape, input2_shape, output_shape); for (int i = 0; i < flat_size; i++) { // F0 uses 0 integer bits, range [-1, 1]. diff --git a/tensorflow/lite/kernels/internal/reference/reference_ops.h b/tensorflow/lite/kernels/internal/reference/reference_ops.h index 21fa7de92cf..304ba5d9a0c 100644 --- a/tensorflow/lite/kernels/internal/reference/reference_ops.h +++ b/tensorflow/lite/kernels/internal/reference/reference_ops.h @@ -350,7 +350,7 @@ inline void Mul(const ArithmeticParams& params, GetActivationParams(params, &output_activation_min, &output_activation_max); const int flat_size = - MatchingFlatSize(input1_shape, input2_shape, output_shape); + MatchingElementsSize(input1_shape, input2_shape, output_shape); for (int i = 0; i < flat_size; ++i) { output_data[i] = ActivationFunctionWithMinMax( input1_data[i] * input2_data[i], output_activation_min, @@ -444,7 +444,7 @@ inline void Mul(const ArithmeticParams& params, params.quantized_activation_max); gemmlowp::ScopedProfilingLabel label("Mul/8bit"); const int flat_size = - MatchingFlatSize(input1_shape, input2_shape, output_shape); + MatchingElementsSize(input1_shape, input2_shape, output_shape); MulElementwise(flat_size, params, input1_data, input2_data, output_data); } @@ -551,7 +551,7 @@ inline void Mul(const ArithmeticParams& params, gemmlowp::ScopedProfilingLabel label("Mul/Int16"); const int flat_size = - MatchingFlatSize(input1_shape, input2_shape, output_shape); + MatchingElementsSize(input1_shape, input2_shape, output_shape); for (int i = 0; i < flat_size; i++) { // F0 uses 0 integer bits, range [-1, 1]. @@ -574,7 +574,7 @@ inline void Mul(const ArithmeticParams& params, TFLITE_DCHECK_LE(output_activation_min, output_activation_max); const int flat_size = - MatchingFlatSize(input1_shape, input2_shape, output_shape); + MatchingElementsSize(input1_shape, input2_shape, output_shape); for (int i = 0; i < flat_size; i++) { // F0 uses 0 integer bits, range [-1, 1]. @@ -655,7 +655,7 @@ inline void Div(const ArithmeticParams& params, GetActivationParams(params, &output_activation_min, &output_activation_max); const int flat_size = - MatchingFlatSize(input1_shape, input2_shape, output_shape); + MatchingElementsSize(input1_shape, input2_shape, output_shape); for (int i = 0; i < flat_size; ++i) { output_data[i] = ActivationFunctionWithMinMax( input1_data[i] / input2_data[i], output_activation_min, @@ -706,7 +706,7 @@ inline void Div(const ArithmeticParams& params, params.quantized_activation_max); gemmlowp::ScopedProfilingLabel label("Div/8bit"); const int flat_size = - MatchingFlatSize(input1_shape, input2_shape, output_shape); + MatchingElementsSize(input1_shape, input2_shape, output_shape); DivElementwise(flat_size, params, input1_data, input2_data, output_data); } @@ -779,7 +779,7 @@ inline void SubNonBroadcast(const ArithmeticParams& params, const RuntimeShape& output_shape, float* output_data) { const int flat_size = - MatchingFlatSize(input1_shape, input2_shape, output_shape); + MatchingElementsSize(input1_shape, input2_shape, output_shape); for (int i = 0; i < flat_size; ++i) { output_data[i] = ActivationFunctionWithMinMax( input1_data[i] - input2_data[i], params.float_activation_min, @@ -795,7 +795,7 @@ inline void SubNonBroadcast(const ArithmeticParams& params, const RuntimeShape& output_shape, int32* output_data) { const int flat_size = - MatchingFlatSize(input1_shape, input2_shape, output_shape); + MatchingElementsSize(input1_shape, input2_shape, output_shape); for (int i = 0; i < flat_size; ++i) { output_data[i] = ActivationFunctionWithMinMax( input1_data[i] - input2_data[i], params.quantized_activation_min, @@ -1043,7 +1043,7 @@ inline void SubWithActivation(const ArithmeticParams& params, int32* output_data) { gemmlowp::ScopedProfilingLabel label("SubWithActivation"); const int flat_size = - MatchingFlatSize(input1_shape, input2_shape, input2_shape); + MatchingElementsSize(input1_shape, input2_shape, output_shape); for (int i = 0; i < flat_size; ++i) { output_data[i] = ActivationFunctionWithMinMax( input1_data[i] - input2_data[i], params.quantized_activation_min, @@ -1059,7 +1059,7 @@ inline void SubWithActivation(const ArithmeticParams& params, const RuntimeShape& output_shape, float* output_data) { const int flat_size = - MatchingFlatSize(input1_shape, input2_shape, input2_shape); + MatchingElementsSize(input1_shape, input2_shape, output_shape); for (int i = 0; i < flat_size; ++i) { output_data[i] = ActivationFunctionWithMinMax( input1_data[i] - input2_data[i], params.float_activation_min, @@ -1074,7 +1074,7 @@ inline void Sub16(const ArithmeticParams& params, gemmlowp::ScopedProfilingLabel label("Sub/Int16"); const int input1_shift = params.input1_shift; const int flat_size = - MatchingFlatSize(output_shape, input1_shape, input2_shape); + MatchingElementsSize(input1_shape, input2_shape, output_shape); const int16 output_activation_min = params.quantized_activation_min; const int16 output_activation_max = params.quantized_activation_max; diff --git a/tensorflow/lite/kernels/internal/types.h b/tensorflow/lite/kernels/internal/types.h index eb7b630c574..1a4a4ee84c3 100644 --- a/tensorflow/lite/kernels/internal/types.h +++ b/tensorflow/lite/kernels/internal/types.h @@ -457,6 +457,25 @@ inline int RequiredBufferSizeForDims(const Dims<4>& dims) { return FlatSize(dims); } +inline int MatchingElementsSize(const RuntimeShape& shape, + const RuntimeShape& check_shape_0) { + const int size_1 = shape.FlatSize(); + const int size_2 = check_shape_0.FlatSize(); + TFLITE_CHECK_EQ(size_1, size_2); + return size_1; +} + +inline int MatchingElementsSize(const RuntimeShape& shape, + const RuntimeShape& check_shape_0, + const RuntimeShape& check_shape_1) { + const int size_1 = shape.FlatSize(); + const int size_2 = check_shape_0.FlatSize(); + const int size_3 = check_shape_1.FlatSize(); + TFLITE_CHECK_EQ(size_1, size_2); + TFLITE_CHECK_EQ(size_2, size_3); + return size_1; +} + // Flat size calculation, checking that dimensions match with one or more other // arrays. inline int MatchingFlatSize(const RuntimeShape& shape, diff --git a/tensorflow/lite/kernels/mul.cc b/tensorflow/lite/kernels/mul.cc index 9e2c3c81780..9feb1794076 100644 --- a/tensorflow/lite/kernels/mul.cc +++ b/tensorflow/lite/kernels/mul.cc @@ -110,8 +110,6 @@ void EvalMul(TfLiteContext* context, TfLiteNode* node, TfLiteMulParams* params, const OpData* data, const TfLiteTensor* input1, const TfLiteTensor* input2, TfLiteTensor* output) { tflite::ArithmeticParams op_params; - // requires_flat_size_broadcast is used for BroadcastMul4DSlow. - const bool requires_flat_size_broadcast = !HaveSameShapes(input1, input2); const bool need_broadcast = optimized_ops::ProcessBroadcastShapes( GetTensorShape(input1), GetTensorShape(input2), &op_params); #define TF_LITE_MUL(type, opname, data_type) \ @@ -127,13 +125,13 @@ void EvalMul(TfLiteContext* context, TfLiteNode* node, TfLiteMulParams* params, if (output->type == kTfLiteInt32) { if (kernel_type == kReference) { - if (requires_flat_size_broadcast) { + if (need_broadcast) { TF_LITE_MUL(reference_ops, BroadcastMul4DSlow, int32_t); } else { TF_LITE_MUL(reference_ops, Mul, int32_t); } } else { - if (requires_flat_size_broadcast) { + if (need_broadcast) { TF_LITE_MUL(optimized_ops, BroadcastMul4DSlow, int32_t); } else { TF_LITE_MUL(optimized_ops, Mul, int32_t); @@ -141,7 +139,7 @@ void EvalMul(TfLiteContext* context, TfLiteNode* node, TfLiteMulParams* params, } } else if (output->type == kTfLiteFloat32) { if (kernel_type == kReference) { - if (requires_flat_size_broadcast) { + if (need_broadcast) { TF_LITE_MUL(reference_ops, BroadcastMul4DSlow, float); } else { TF_LITE_MUL(reference_ops, Mul, float); @@ -149,8 +147,6 @@ void EvalMul(TfLiteContext* context, TfLiteNode* node, TfLiteMulParams* params, } else { if (need_broadcast) { TF_LITE_MUL(optimized_ops, BroadcastMulFivefold, float); - } else if (requires_flat_size_broadcast) { - TF_LITE_MUL(optimized_ops, BroadcastMul4DSlow, float); } else { TF_LITE_MUL(optimized_ops, Mul, float); }