remove unnecessary requires_flat_size_broadcast check.

PiperOrigin-RevId: 275175433
Change-Id: I084990420ed4a7f3cbb32ac554d0b2f65537be78

parent d0de62189a
commit 0a0dbde0a4

tensorflow/lite/kernels
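Why this is safe: optimized_ops::ProcessBroadcastShapes already reports whether a genuine broadcast is needed, so the extra HaveSameShapes() guard only served to push same-element-count inputs of different rank onto the slow BroadcastAdd4DSlow/BroadcastMul4DSlow path. With the elementwise kernels relaxed from MatchingFlatSize (dimension-by-dimension equality) to the new MatchingElementsSize (equal element counts), that detour is no longer necessary. A toy model of the dispatch decision (a simplified stand-in, not the TFLite API):

    // Simplified stand-in for optimized_ops::ProcessBroadcastShapes(): true
    // iff some trailing-aligned dimension pair differs, i.e. a genuine
    // broadcast is required. (The real function also classifies the broadcast
    // pattern into op_params; that part is omitted here.)
    #include <cstdio>
    #include <vector>

    using Shape = std::vector<int>;

    bool NeedBroadcast(const Shape& a, const Shape& b) {
      int ia = static_cast<int>(a.size());
      int ib = static_cast<int>(b.size());
      while (ia > 0 || ib > 0) {
        const int da = ia > 0 ? a[--ia] : 1;  // missing dims count as 1
        const int db = ib > 0 ? b[--ib] : 1;
        if (da != db) return true;
      }
      return false;
    }

    int main() {
      // {1,2} vs {2}: HaveSameShapes() is false, so the old code took the
      // slow broadcast kernel; no broadcast is actually needed, and the fast
      // elementwise kernel's MatchingElementsSize check (2 == 2) now passes.
      std::printf("need_broadcast({1,2},{2})   = %d\n",
                  NeedBroadcast({1, 2}, {2}));  // 0
      // {1,2} vs {2,1}: a real broadcast (output {2,2}), still dispatched to
      // the broadcast kernels.
      std::printf("need_broadcast({1,2},{2,1}) = %d\n",
                  NeedBroadcast({1, 2}, {2, 1}));  // 1
    }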
@@ -177,8 +177,6 @@ void EvalAdd(TfLiteContext* context, TfLiteNode* node, TfLiteAddParams* params,
              const OpData* data, const TfLiteTensor* input1,
              const TfLiteTensor* input2, TfLiteTensor* output) {
   tflite::ArithmeticParams op_params;
-  // requires_flat_size_broadcast is used for BroadcastAdd4DSlow.
-  const bool requires_flat_size_broadcast = !HaveSameShapes(input1, input2);
   const bool need_broadcast = optimized_ops::ProcessBroadcastShapes(
       GetTensorShape(input1), GetTensorShape(input2), &op_params);
 #define TF_LITE_ADD(type, opname, data_type) \
@@ -193,13 +191,13 @@ void EvalAdd(TfLiteContext* context, TfLiteNode* node, TfLiteAddParams* params,
       GetTensorData<data_type>(output))
   if (output->type == kTfLiteInt32) {
     if (kernel_type == kReference) {
-      if (requires_flat_size_broadcast) {
+      if (need_broadcast) {
         TF_LITE_ADD(reference_ops, BroadcastAdd4DSlow, int32_t);
       } else {
         TF_LITE_ADD(reference_ops, Add, int32_t);
       }
     } else {
-      if (requires_flat_size_broadcast) {
+      if (need_broadcast) {
         TF_LITE_ADD(optimized_ops, BroadcastAdd4DSlow, int32_t);
       } else {
         TF_LITE_ADD(optimized_ops, Add, int32_t);
@@ -207,7 +205,7 @@ void EvalAdd(TfLiteContext* context, TfLiteNode* node, TfLiteAddParams* params,
     }
   } else if (output->type == kTfLiteFloat32) {
     if (kernel_type == kReference) {
-      if (requires_flat_size_broadcast) {
+      if (need_broadcast) {
         TF_LITE_ADD(reference_ops, BroadcastAdd4DSlow, float);
       } else {
         TF_LITE_ADD(reference_ops, Add, float);
@@ -215,8 +213,6 @@ void EvalAdd(TfLiteContext* context, TfLiteNode* node, TfLiteAddParams* params,
     } else {
       if (need_broadcast) {
         TF_LITE_ADD(optimized_ops, BroadcastAddFivefold, float);
-      } else if (requires_flat_size_broadcast) {
-        TF_LITE_ADD(optimized_ops, BroadcastAdd4DSlow, float);
       } else {
         TF_LITE_ADD(optimized_ops, Add, float);
       }

@@ -190,6 +190,24 @@ TEST(IntegerAddOpModel, WithBroadcast) {
   }
 }
 
+TEST(IntegerAddOpModel, Int32MultiDimBroadcast) {
+  IntegerAddOpModel m({TensorType_INT32, {1, 2}}, {TensorType_INT32, {2, 1}},
+                      {TensorType_INT32, {}}, ActivationFunctionType_NONE);
+  m.PopulateTensor<int32_t>(m.input1(), {3, 5});
+  m.PopulateTensor<int32_t>(m.input2(), {1, 4});
+  m.Invoke();
+  EXPECT_THAT(m.GetOutput(), ElementsAreArray({4, 6, 7, 9}));
+}
+
+TEST(IntegerAddOpModel, Float32MultiDimBroadcast) {
+  FloatAddOpModel m({TensorType_FLOAT32, {1, 2}}, {TensorType_FLOAT32, {2, 1}},
+                    {TensorType_FLOAT32, {}}, ActivationFunctionType_NONE);
+  m.PopulateTensor<float>(m.input1(), {3, 5});
+  m.PopulateTensor<float>(m.input2(), {1, 4});
+  m.Invoke();
+  EXPECT_THAT(m.GetOutput(), ElementsAreArray({4, 6, 7, 9}));
+}
+
 template <TensorType tensor_type, typename integer_dtype>
 void QuantizedTestsNoActivation() {
   float kQuantizedTolerance = GetTolerance(-1.0, 1.0);
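Note on the expected values: broadcasting the {1, 2} row [3, 5] against the {2, 1} column [1, 4] produces a {2, 2} output whose entry at (row r, column c) is input2[r] + input1[c], flattened row-major to {4, 6, 7, 9}. A hand check with plain arrays standing in for the op-model tensors:

    #include <cstdio>

    int main() {
      const int input1[2] = {3, 5};  // shape {1, 2}: a single row
      const int input2[2] = {1, 4};  // shape {2, 1}: a single column
      for (int r = 0; r < 2; ++r) {
        for (int c = 0; c < 2; ++c) {
          std::printf("%d ", input1[c] + input2[r]);  // prints: 4 6 7 9
        }
      }
      std::printf("\n");
    }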

@@ -221,7 +221,7 @@ inline void Add(const ArithmeticParams& params,
                    params.quantized_activation_max);
   gemmlowp::ScopedProfilingLabel label("AddInt8/8bit");
   const int flat_size =
-      MatchingFlatSize(input1_shape, input2_shape, output_shape);
+      MatchingElementsSize(input1_shape, input2_shape, output_shape);
 
   TFLITE_DCHECK_GT(params.input1_offset, -256);
   TFLITE_DCHECK_GT(params.input2_offset, -256);

@@ -175,7 +175,7 @@ inline void Mul(const ArithmeticParams& params,
                    params.quantized_activation_max);
   gemmlowp::ScopedProfilingLabel label("MulInt8/8bit");
   const int flat_size =
-      MatchingFlatSize(input1_shape, input2_shape, output_shape);
+      MatchingElementsSize(input1_shape, input2_shape, output_shape);
 
   MulElementwise(flat_size, params, input1_data, input2_data, output_data);
 }

@@ -1543,7 +1543,7 @@ inline void Add(const ArithmeticParams& params,
                 const RuntimeShape& output_shape, float* output_data) {
   gemmlowp::ScopedProfilingLabel label("Add");
   const int flat_size =
-      MatchingFlatSize(input1_shape, input2_shape, output_shape);
+      MatchingElementsSize(input1_shape, input2_shape, output_shape);
   AddElementwise(flat_size, params, input1_data, input2_data, output_data);
 }
 
@@ -1782,7 +1782,7 @@ inline void Add(const ArithmeticParams& params,
                    params.quantized_activation_max);
   gemmlowp::ScopedProfilingLabel label("Add/8bit");
   const int flat_size =
-      MatchingFlatSize(input1_shape, input2_shape, output_shape);
+      MatchingElementsSize(input1_shape, input2_shape, output_shape);
 
   TFLITE_DCHECK_GT(params.input1_offset, -256);
   TFLITE_DCHECK_GT(params.input2_offset, -256);
@@ -1801,7 +1801,7 @@ inline void Add(const ArithmeticParams& params,
 
   const int input1_shift = params.input1_shift;
   const int flat_size =
-      MatchingFlatSize(output_shape, input1_shape, input2_shape);
+      MatchingElementsSize(input1_shape, input2_shape, output_shape);
   const int16 output_activation_min = params.quantized_activation_min;
   const int16 output_activation_max = params.quantized_activation_max;
 
@@ -1846,8 +1846,10 @@ inline void Add(const ArithmeticParams& params,
     auto scalar = input1_data[0];
     output_map.array() = scalar + input2_map.array();
   } else {
-    // Should not come here.
-    TFLITE_DCHECK(false);
+    reference_ops::BroadcastAdd4DSlow(params, input1_shape, input1_data,
+                                      input2_shape, input2_data, output_shape,
+                                      output_data);
+    return;
   }
   output_map = output_map.cwiseMax(params.quantized_activation_min);
   output_map = output_map.cwiseMin(params.quantized_activation_max);
@@ -2097,7 +2099,7 @@ inline void Mul(const ArithmeticParams& params,
   gemmlowp::ScopedProfilingLabel label("Mul");
 
   const int flat_size =
-      MatchingFlatSize(input1_shape, input2_shape, output_shape);
+      MatchingElementsSize(input1_shape, input2_shape, output_shape);
   MulElementwise(flat_size, params, input1_data, input2_data, output_data);
 }
 
@@ -2108,7 +2110,7 @@ inline void Mul(const ArithmeticParams& params,
   gemmlowp::ScopedProfilingLabel label("Mul/int32/activation");
 
   const int flat_size =
-      MatchingFlatSize(input1_shape, input2_shape, output_shape);
+      MatchingElementsSize(input1_shape, input2_shape, output_shape);
   const int32 output_activation_min = params.quantized_activation_min;
   const int32 output_activation_max = params.quantized_activation_max;
   for (int i = 0; i < flat_size; ++i) {
@@ -2139,8 +2141,9 @@ inline void MulNoActivation(const ArithmeticParams& params,
     auto scalar = input1_data[0];
     output_map.array() = scalar * input2_map.array();
   } else {
-    // Should not come here.
-    TFLITE_DCHECK(false);
+    reference_ops::BroadcastMul4DSlow(params, input1_shape, input1_data,
+                                      input2_shape, input2_data, output_shape,
+                                      output_data);
   }
 }
 
@@ -2153,7 +2156,7 @@ inline void Mul(const ArithmeticParams& params,
   // properly optimized version.
 
   const int flat_size =
-      MatchingFlatSize(input1_shape, input2_shape, output_shape);
+      MatchingElementsSize(input1_shape, input2_shape, output_shape);
 
   for (int i = 0; i < flat_size; i++) {
     // F0 uses 0 integer bits, range [-1, 1].
@@ -2178,7 +2181,7 @@ inline void Mul(const ArithmeticParams& params,
   TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
 
   const int flat_size =
-      MatchingFlatSize(input1_shape, input2_shape, output_shape);
+      MatchingElementsSize(input1_shape, input2_shape, output_shape);
 
   for (int i = 0; i < flat_size; i++) {
     // F0 uses 0 integer bits, range [-1, 1].
@@ -2342,6 +2345,8 @@ inline void MulSimpleBroadcast(int size, const ArithmeticParams& params,
 }
 
 // Broadcast mul that can often be used for inner loop of broadcast Mul.
+// This function will handle scalar_value (LHS) * vector_values (RHS).
+// Since it's a float function, input params does not matter here.
 inline void MulSimpleBroadcast(int size, const ArithmeticParams& params,
                                const float broadcast_value,
                                const float* input2_data, float* output_data) {
@@ -2380,7 +2385,7 @@ inline void Mul(const ArithmeticParams& params,
                    params.quantized_activation_max);
   gemmlowp::ScopedProfilingLabel label("Mul/8bit");
   const int flat_size =
-      MatchingFlatSize(input1_shape, input2_shape, output_shape);
+      MatchingElementsSize(input1_shape, input2_shape, output_shape);
 
   MulElementwise(flat_size, params, input1_data, input2_data, output_data);
 }
@@ -2509,6 +2514,8 @@ inline void BroadcastMulFivefold(const ArithmeticParams& params,
     for (int i1 = 0; i1 < y1; ++i1) {
       input2_data_ptr = input2_data_reset;
       for (int i2 = 0; i2 < y2; ++i2) {
+        // The input may be switched here, but the common parameters here
+        // do not matter as they will not influence the float math execution.
         MulSimpleBroadcast(y3, params, *input1_data_ptr, input2_data_ptr,
                            output_data_ptr);
         input2_data_ptr += y3;
@@ -2652,7 +2659,7 @@ inline void SubNonBroadcast(const ArithmeticParams& params,
                             float* output_data) {
   gemmlowp::ScopedProfilingLabel label("SubNonBroadcast");
   const int flat_size =
-      MatchingFlatSize(input1_shape, input2_shape, output_shape);
+      MatchingElementsSize(input1_shape, input2_shape, output_shape);
   for (int i = 0; i < flat_size; ++i) {
     output_data[i] = ActivationFunctionWithMinMax(
         input1_data[i] - input2_data[i], params.float_activation_min,
@@ -2669,7 +2676,7 @@ inline void SubWithActivation(const ArithmeticParams& params,
                               int32* output_data) {
   gemmlowp::ScopedProfilingLabel label("SubWithActivation/int32");
   const int flat_size =
-      MatchingFlatSize(input1_shape, input2_shape, input2_shape);
+      MatchingElementsSize(input1_shape, input2_shape, output_shape);
   for (int i = 0; i < flat_size; ++i) {
     output_data[i] = ActivationFunctionWithMinMax(
         input1_data[i] - input2_data[i], params.quantized_activation_min,
@@ -2686,7 +2693,7 @@ inline void SubWithActivation(const ArithmeticParams& params,
                               float* output_data) {
   gemmlowp::ScopedProfilingLabel label("SubWithActivation/float");
   const int flat_size =
-      MatchingFlatSize(input1_shape, input2_shape, input2_shape);
+      MatchingElementsSize(input1_shape, input2_shape, output_shape);
   for (int i = 0; i < flat_size; ++i) {
     output_data[i] = ActivationFunctionWithMinMax(
         input1_data[i] - input2_data[i], params.float_activation_min,
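Note: besides relaxing the size checks, the hunks at old lines 1846 and 2139 above turn an unreachable-by-assumption branch into a working fallback: shape layouts the Eigen fast paths cannot handle are now delegated to the reference broadcast kernels instead of tripping TFLITE_DCHECK(false) in debug builds (and computing nothing in release builds). A minimal sketch of the pattern (hypothetical helper names, not the TFLite API):

    #include <cstdio>

    enum class Layout { kScalarLhs, kScalarRhs, kOther };

    void EigenScalarBroadcast() { std::printf("fast path\n"); }
    void ReferenceBroadcast4DSlow() { std::printf("reference fallback\n"); }

    void OptimizedBinaryOp(Layout layout) {
      if (layout == Layout::kScalarLhs || layout == Layout::kScalarRhs) {
        EigenScalarBroadcast();
      } else {
        // Before: TFLITE_DCHECK(false);  // "Should not come here."
        ReferenceBroadcast4DSlow();
        return;  // skip the fast path's activation clamping below
      }
      // ... clamp the fast-path result to the activation range ...
    }

    int main() { OptimizedBinaryOp(Layout::kOther); }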

@@ -28,7 +28,7 @@ inline void Add(const ArithmeticParams& params,
                 const RuntimeShape& input2_shape, const T* input2_data,
                 const RuntimeShape& output_shape, T* output_data) {
   const int flat_size =
-      MatchingFlatSize(input1_shape, input2_shape, output_shape);
+      MatchingElementsSize(input1_shape, input2_shape, output_shape);
   for (int i = 0; i < flat_size; ++i) {
     output_data[i] = ActivationFunctionWithMinMax(
         input1_data[i] + input2_data[i], params.quantized_activation_min,
@@ -40,8 +40,9 @@ inline void Add(const ArithmeticParams& params,
                 const RuntimeShape& input1_shape, const float* input1_data,
                 const RuntimeShape& input2_shape, const float* input2_data,
                 const RuntimeShape& output_shape, float* output_data) {
-  const int size = MatchingFlatSize(input1_shape, input2_shape, output_shape);
-  for (int i = 0; i < size; i++) {
+  const int flat_size =
+      MatchingElementsSize(input1_shape, input2_shape, output_shape);
+  for (int i = 0; i < flat_size; i++) {
     auto x = input1_data[i] + input2_data[i];
     output_data[i] = ActivationFunctionWithMinMax(
         x, params.float_activation_min, params.float_activation_max);
@@ -122,7 +123,7 @@ inline void Add(const ArithmeticParams& params,
   TFLITE_DCHECK_LE(params.quantized_activation_min,
                    params.quantized_activation_max);
   const int flat_size =
-      MatchingFlatSize(input1_shape, input2_shape, output_shape);
+      MatchingElementsSize(input1_shape, input2_shape, output_shape);
 
   TFLITE_DCHECK_GT(params.input1_offset, -256);
   TFLITE_DCHECK_GT(params.input2_offset, -256);
@@ -140,7 +141,7 @@ inline void Add(const ArithmeticParams& params,
 
   const int input1_shift = params.input1_shift;
   const int flat_size =
-      MatchingFlatSize(output_shape, input1_shape, input2_shape);
+      MatchingElementsSize(input1_shape, input2_shape, output_shape);
   const int16 output_activation_min = params.quantized_activation_min;
   const int16 output_activation_max = params.quantized_activation_max;
 

@@ -64,7 +64,7 @@ inline void Add(const ArithmeticParams& params,
   TFLITE_DCHECK_LE(params.quantized_activation_min,
                    params.quantized_activation_max);
   const int flat_size =
-      MatchingFlatSize(input1_shape, input2_shape, output_shape);
+      MatchingElementsSize(input1_shape, input2_shape, output_shape);
 
   const int32_t int8_max_value = std::numeric_limits<int8_t>::max();
   TFLITE_DCHECK_GE(params.input1_offset, -1 * int8_max_value);

@@ -48,7 +48,7 @@ inline void Mul(const ArithmeticParams& params,
                    params.quantized_activation_max);
   gemmlowp::ScopedProfilingLabel label("Mul/8bit");
   const int flat_size =
-      MatchingFlatSize(input1_shape, input2_shape, output_shape);
+      MatchingElementsSize(input1_shape, input2_shape, output_shape);
 
   MulElementwise(flat_size, params, input1_data, input2_data, output_data);
 }
@@ -65,7 +65,7 @@ inline void Mul(const ArithmeticParams& params,
   TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
 
   const int flat_size =
-      MatchingFlatSize(input1_shape, input2_shape, output_shape);
+      MatchingElementsSize(input1_shape, input2_shape, output_shape);
 
   for (int i = 0; i < flat_size; i++) {
     // F0 uses 0 integer bits, range [-1, 1].

@@ -350,7 +350,7 @@ inline void Mul(const ArithmeticParams& params,
   GetActivationParams(params, &output_activation_min, &output_activation_max);
 
   const int flat_size =
-      MatchingFlatSize(input1_shape, input2_shape, output_shape);
+      MatchingElementsSize(input1_shape, input2_shape, output_shape);
   for (int i = 0; i < flat_size; ++i) {
     output_data[i] = ActivationFunctionWithMinMax(
         input1_data[i] * input2_data[i], output_activation_min,
@@ -444,7 +444,7 @@ inline void Mul(const ArithmeticParams& params,
                    params.quantized_activation_max);
   gemmlowp::ScopedProfilingLabel label("Mul/8bit");
   const int flat_size =
-      MatchingFlatSize(input1_shape, input2_shape, output_shape);
+      MatchingElementsSize(input1_shape, input2_shape, output_shape);
 
   MulElementwise(flat_size, params, input1_data, input2_data, output_data);
 }
@@ -551,7 +551,7 @@ inline void Mul(const ArithmeticParams& params,
   gemmlowp::ScopedProfilingLabel label("Mul/Int16");
 
   const int flat_size =
-      MatchingFlatSize(input1_shape, input2_shape, output_shape);
+      MatchingElementsSize(input1_shape, input2_shape, output_shape);
 
   for (int i = 0; i < flat_size; i++) {
     // F0 uses 0 integer bits, range [-1, 1].
@@ -574,7 +574,7 @@ inline void Mul(const ArithmeticParams& params,
   TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
 
   const int flat_size =
-      MatchingFlatSize(input1_shape, input2_shape, output_shape);
+      MatchingElementsSize(input1_shape, input2_shape, output_shape);
 
   for (int i = 0; i < flat_size; i++) {
     // F0 uses 0 integer bits, range [-1, 1].
@@ -655,7 +655,7 @@ inline void Div(const ArithmeticParams& params,
   GetActivationParams(params, &output_activation_min, &output_activation_max);
 
   const int flat_size =
-      MatchingFlatSize(input1_shape, input2_shape, output_shape);
+      MatchingElementsSize(input1_shape, input2_shape, output_shape);
   for (int i = 0; i < flat_size; ++i) {
     output_data[i] = ActivationFunctionWithMinMax(
         input1_data[i] / input2_data[i], output_activation_min,
@@ -706,7 +706,7 @@ inline void Div(const ArithmeticParams& params,
                    params.quantized_activation_max);
   gemmlowp::ScopedProfilingLabel label("Div/8bit");
   const int flat_size =
-      MatchingFlatSize(input1_shape, input2_shape, output_shape);
+      MatchingElementsSize(input1_shape, input2_shape, output_shape);
 
   DivElementwise(flat_size, params, input1_data, input2_data, output_data);
 }
@@ -779,7 +779,7 @@ inline void SubNonBroadcast(const ArithmeticParams& params,
                             const RuntimeShape& output_shape,
                             float* output_data) {
   const int flat_size =
-      MatchingFlatSize(input1_shape, input2_shape, output_shape);
+      MatchingElementsSize(input1_shape, input2_shape, output_shape);
   for (int i = 0; i < flat_size; ++i) {
     output_data[i] = ActivationFunctionWithMinMax(
         input1_data[i] - input2_data[i], params.float_activation_min,
@@ -795,7 +795,7 @@ inline void SubNonBroadcast(const ArithmeticParams& params,
                             const RuntimeShape& output_shape,
                             int32* output_data) {
   const int flat_size =
-      MatchingFlatSize(input1_shape, input2_shape, output_shape);
+      MatchingElementsSize(input1_shape, input2_shape, output_shape);
   for (int i = 0; i < flat_size; ++i) {
     output_data[i] = ActivationFunctionWithMinMax(
         input1_data[i] - input2_data[i], params.quantized_activation_min,
@@ -1043,7 +1043,7 @@ inline void SubWithActivation(const ArithmeticParams& params,
                               int32* output_data) {
   gemmlowp::ScopedProfilingLabel label("SubWithActivation");
   const int flat_size =
-      MatchingFlatSize(input1_shape, input2_shape, input2_shape);
+      MatchingElementsSize(input1_shape, input2_shape, output_shape);
   for (int i = 0; i < flat_size; ++i) {
     output_data[i] = ActivationFunctionWithMinMax(
         input1_data[i] - input2_data[i], params.quantized_activation_min,
@@ -1059,7 +1059,7 @@ inline void SubWithActivation(const ArithmeticParams& params,
                               const RuntimeShape& output_shape,
                               float* output_data) {
   const int flat_size =
-      MatchingFlatSize(input1_shape, input2_shape, input2_shape);
+      MatchingElementsSize(input1_shape, input2_shape, output_shape);
   for (int i = 0; i < flat_size; ++i) {
     output_data[i] = ActivationFunctionWithMinMax(
         input1_data[i] - input2_data[i], params.float_activation_min,
@@ -1074,7 +1074,7 @@ inline void Sub16(const ArithmeticParams& params,
   gemmlowp::ScopedProfilingLabel label("Sub/Int16");
   const int input1_shift = params.input1_shift;
   const int flat_size =
-      MatchingFlatSize(output_shape, input1_shape, input2_shape);
+      MatchingElementsSize(input1_shape, input2_shape, output_shape);
   const int16 output_activation_min = params.quantized_activation_min;
   const int16 output_activation_max = params.quantized_activation_max;
 
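Note: the SubWithActivation and Sub16 hunks do more than relax the check. The old calls were inconsistent and, in two cases, incomplete: MatchingFlatSize(input1_shape, input2_shape, input2_shape) passed input2_shape twice and never validated output_shape at all, while MatchingFlatSize(output_shape, input1_shape, input2_shape) merely ordered its arguments differently from the other kernels. Every call site now uniformly reads MatchingElementsSize(input1_shape, input2_shape, output_shape), so all three shapes participate in the element-count check.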

@@ -457,6 +457,25 @@ inline int RequiredBufferSizeForDims(const Dims<4>& dims) {
   return FlatSize(dims);
 }
 
+inline int MatchingElementsSize(const RuntimeShape& shape,
+                                const RuntimeShape& check_shape_0) {
+  const int size_1 = shape.FlatSize();
+  const int size_2 = check_shape_0.FlatSize();
+  TFLITE_CHECK_EQ(size_1, size_2);
+  return size_1;
+}
+
+inline int MatchingElementsSize(const RuntimeShape& shape,
+                                const RuntimeShape& check_shape_0,
+                                const RuntimeShape& check_shape_1) {
+  const int size_1 = shape.FlatSize();
+  const int size_2 = check_shape_0.FlatSize();
+  const int size_3 = check_shape_1.FlatSize();
+  TFLITE_CHECK_EQ(size_1, size_2);
+  TFLITE_CHECK_EQ(size_2, size_3);
+  return size_1;
+}
+
 // Flat size calculation, checking that dimensions match with one or more other
 // arrays.
 inline int MatchingFlatSize(const RuntimeShape& shape,
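For reference, the practical difference from MatchingFlatSize: that helper asserts dimension-by-dimension equality, so two tensors with the same element count but different rank, such as {4} and {1, 4}, fail its check even though elementwise arithmetic over them is well defined. A standalone sketch of the two-shape overload's semantics (plain vectors standing in for RuntimeShape; an assumption, not the TFLite types):

    #include <cassert>
    #include <functional>
    #include <numeric>
    #include <vector>

    using Shape = std::vector<int>;

    int FlatSize(const Shape& s) {
      return std::accumulate(s.begin(), s.end(), 1, std::multiplies<int>());
    }

    // Mirrors the two-shape MatchingElementsSize added above.
    int MatchingElementsSize(const Shape& shape, const Shape& check_shape_0) {
      const int size_1 = FlatSize(shape);
      const int size_2 = FlatSize(check_shape_0);
      assert(size_1 == size_2);  // TFLITE_CHECK_EQ in the real helper
      return size_1;
    }

    int main() {
      // Same 4 elements, different rank: now fine for elementwise kernels.
      return MatchingElementsSize({4}, {1, 4}) == 4 ? 0 : 1;
    }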

@@ -110,8 +110,6 @@ void EvalMul(TfLiteContext* context, TfLiteNode* node, TfLiteMulParams* params,
              const OpData* data, const TfLiteTensor* input1,
              const TfLiteTensor* input2, TfLiteTensor* output) {
   tflite::ArithmeticParams op_params;
-  // requires_flat_size_broadcast is used for BroadcastMul4DSlow.
-  const bool requires_flat_size_broadcast = !HaveSameShapes(input1, input2);
   const bool need_broadcast = optimized_ops::ProcessBroadcastShapes(
       GetTensorShape(input1), GetTensorShape(input2), &op_params);
 #define TF_LITE_MUL(type, opname, data_type) \
@@ -127,13 +125,13 @@ void EvalMul(TfLiteContext* context, TfLiteNode* node, TfLiteMulParams* params,
 
   if (output->type == kTfLiteInt32) {
     if (kernel_type == kReference) {
-      if (requires_flat_size_broadcast) {
+      if (need_broadcast) {
         TF_LITE_MUL(reference_ops, BroadcastMul4DSlow, int32_t);
       } else {
         TF_LITE_MUL(reference_ops, Mul, int32_t);
       }
     } else {
-      if (requires_flat_size_broadcast) {
+      if (need_broadcast) {
         TF_LITE_MUL(optimized_ops, BroadcastMul4DSlow, int32_t);
       } else {
         TF_LITE_MUL(optimized_ops, Mul, int32_t);
@@ -141,7 +139,7 @@ void EvalMul(TfLiteContext* context, TfLiteNode* node, TfLiteMulParams* params,
     }
   } else if (output->type == kTfLiteFloat32) {
     if (kernel_type == kReference) {
-      if (requires_flat_size_broadcast) {
+      if (need_broadcast) {
         TF_LITE_MUL(reference_ops, BroadcastMul4DSlow, float);
       } else {
         TF_LITE_MUL(reference_ops, Mul, float);
@@ -149,8 +147,6 @@ void EvalMul(TfLiteContext* context, TfLiteNode* node, TfLiteMulParams* params,
     } else {
       if (need_broadcast) {
         TF_LITE_MUL(optimized_ops, BroadcastMulFivefold, float);
-      } else if (requires_flat_size_broadcast) {
-        TF_LITE_MUL(optimized_ops, BroadcastMul4DSlow, float);
       } else {
         TF_LITE_MUL(optimized_ops, Mul, float);
       }