Remove the unnecessary requires_flat_size_broadcast check; need_broadcast (from ProcessBroadcastShapes) now covers that case, falling back to BroadcastAdd4DSlow/BroadcastMul4DSlow when the fivefold fast path does not apply.

PiperOrigin-RevId: 275175433
Change-Id: I084990420ed4a7f3cbb32ac554d0b2f65537be78
This commit is contained in:
Renjie Liu 2019-10-16 21:01:14 -07:00 committed by TensorFlower Gardener
parent d0de62189a
commit 0a0dbde0a4
11 changed files with 87 additions and 50 deletions
tensorflow/lite/kernels

View File

@ -177,8 +177,6 @@ void EvalAdd(TfLiteContext* context, TfLiteNode* node, TfLiteAddParams* params,
const OpData* data, const TfLiteTensor* input1,
const TfLiteTensor* input2, TfLiteTensor* output) {
tflite::ArithmeticParams op_params;
// requires_flat_size_broadcast is used for BroadcastAdd4DSlow.
const bool requires_flat_size_broadcast = !HaveSameShapes(input1, input2);
const bool need_broadcast = optimized_ops::ProcessBroadcastShapes(
GetTensorShape(input1), GetTensorShape(input2), &op_params);
#define TF_LITE_ADD(type, opname, data_type) \
@ -193,13 +191,13 @@ void EvalAdd(TfLiteContext* context, TfLiteNode* node, TfLiteAddParams* params,
GetTensorData<data_type>(output))
if (output->type == kTfLiteInt32) {
if (kernel_type == kReference) {
if (requires_flat_size_broadcast) {
if (need_broadcast) {
TF_LITE_ADD(reference_ops, BroadcastAdd4DSlow, int32_t);
} else {
TF_LITE_ADD(reference_ops, Add, int32_t);
}
} else {
if (requires_flat_size_broadcast) {
if (need_broadcast) {
TF_LITE_ADD(optimized_ops, BroadcastAdd4DSlow, int32_t);
} else {
TF_LITE_ADD(optimized_ops, Add, int32_t);
@ -207,7 +205,7 @@ void EvalAdd(TfLiteContext* context, TfLiteNode* node, TfLiteAddParams* params,
}
} else if (output->type == kTfLiteFloat32) {
if (kernel_type == kReference) {
if (requires_flat_size_broadcast) {
if (need_broadcast) {
TF_LITE_ADD(reference_ops, BroadcastAdd4DSlow, float);
} else {
TF_LITE_ADD(reference_ops, Add, float);
@ -215,8 +213,6 @@ void EvalAdd(TfLiteContext* context, TfLiteNode* node, TfLiteAddParams* params,
} else {
if (need_broadcast) {
TF_LITE_ADD(optimized_ops, BroadcastAddFivefold, float);
} else if (requires_flat_size_broadcast) {
TF_LITE_ADD(optimized_ops, BroadcastAdd4DSlow, float);
} else {
TF_LITE_ADD(optimized_ops, Add, float);
}

View File

@ -190,6 +190,24 @@ TEST(IntegerAddOpModel, WithBroadcast) {
}
}
// Broadcasting across more than one dimension: adding a {1, 2} tensor to a
// {2, 1} tensor must produce the full {2, 2} outer-sum result.
TEST(IntegerAddOpModel, Int32MultiDimBroadcast) {
  IntegerAddOpModel model({TensorType_INT32, {1, 2}},
                          {TensorType_INT32, {2, 1}}, {TensorType_INT32, {}},
                          ActivationFunctionType_NONE);
  model.PopulateTensor<int32_t>(model.input1(), {3, 5});
  model.PopulateTensor<int32_t>(model.input2(), {1, 4});
  model.Invoke();
  // [3 5] + [1; 4] broadcasts to [[4 6], [7 9]].
  EXPECT_THAT(model.GetOutput(), ElementsAreArray({4, 6, 7, 9}));
}
// Same multi-dimensional broadcast as the int32 case, exercised through the
// float path: {1, 2} + {2, 1} -> {2, 2}.
TEST(IntegerAddOpModel, Float32MultiDimBroadcast) {
  FloatAddOpModel model({TensorType_FLOAT32, {1, 2}},
                        {TensorType_FLOAT32, {2, 1}}, {TensorType_FLOAT32, {}},
                        ActivationFunctionType_NONE);
  model.PopulateTensor<float>(model.input1(), {3, 5});
  model.PopulateTensor<float>(model.input2(), {1, 4});
  model.Invoke();
  // [3 5] + [1; 4] broadcasts to [[4 6], [7 9]].
  EXPECT_THAT(model.GetOutput(), ElementsAreArray({4, 6, 7, 9}));
}
template <TensorType tensor_type, typename integer_dtype>
void QuantizedTestsNoActivation() {
float kQuantizedTolerance = GetTolerance(-1.0, 1.0);

View File

@ -221,7 +221,7 @@ inline void Add(const ArithmeticParams& params,
params.quantized_activation_max);
gemmlowp::ScopedProfilingLabel label("AddInt8/8bit");
const int flat_size =
MatchingFlatSize(input1_shape, input2_shape, output_shape);
MatchingElementsSize(input1_shape, input2_shape, output_shape);
TFLITE_DCHECK_GT(params.input1_offset, -256);
TFLITE_DCHECK_GT(params.input2_offset, -256);

View File

@ -175,7 +175,7 @@ inline void Mul(const ArithmeticParams& params,
params.quantized_activation_max);
gemmlowp::ScopedProfilingLabel label("MulInt8/8bit");
const int flat_size =
MatchingFlatSize(input1_shape, input2_shape, output_shape);
MatchingElementsSize(input1_shape, input2_shape, output_shape);
MulElementwise(flat_size, params, input1_data, input2_data, output_data);
}

View File

@ -1543,7 +1543,7 @@ inline void Add(const ArithmeticParams& params,
const RuntimeShape& output_shape, float* output_data) {
gemmlowp::ScopedProfilingLabel label("Add");
const int flat_size =
MatchingFlatSize(input1_shape, input2_shape, output_shape);
MatchingElementsSize(input1_shape, input2_shape, output_shape);
AddElementwise(flat_size, params, input1_data, input2_data, output_data);
}
@ -1782,7 +1782,7 @@ inline void Add(const ArithmeticParams& params,
params.quantized_activation_max);
gemmlowp::ScopedProfilingLabel label("Add/8bit");
const int flat_size =
MatchingFlatSize(input1_shape, input2_shape, output_shape);
MatchingElementsSize(input1_shape, input2_shape, output_shape);
TFLITE_DCHECK_GT(params.input1_offset, -256);
TFLITE_DCHECK_GT(params.input2_offset, -256);
@ -1801,7 +1801,7 @@ inline void Add(const ArithmeticParams& params,
const int input1_shift = params.input1_shift;
const int flat_size =
MatchingFlatSize(output_shape, input1_shape, input2_shape);
MatchingElementsSize(input1_shape, input2_shape, output_shape);
const int16 output_activation_min = params.quantized_activation_min;
const int16 output_activation_max = params.quantized_activation_max;
@ -1846,8 +1846,10 @@ inline void Add(const ArithmeticParams& params,
auto scalar = input1_data[0];
output_map.array() = scalar + input2_map.array();
} else {
// Should not come here.
TFLITE_DCHECK(false);
reference_ops::BroadcastAdd4DSlow(params, input1_shape, input1_data,
input2_shape, input2_data, output_shape,
output_data);
return;
}
output_map = output_map.cwiseMax(params.quantized_activation_min);
output_map = output_map.cwiseMin(params.quantized_activation_max);
@ -2097,7 +2099,7 @@ inline void Mul(const ArithmeticParams& params,
gemmlowp::ScopedProfilingLabel label("Mul");
const int flat_size =
MatchingFlatSize(input1_shape, input2_shape, output_shape);
MatchingElementsSize(input1_shape, input2_shape, output_shape);
MulElementwise(flat_size, params, input1_data, input2_data, output_data);
}
@ -2108,7 +2110,7 @@ inline void Mul(const ArithmeticParams& params,
gemmlowp::ScopedProfilingLabel label("Mul/int32/activation");
const int flat_size =
MatchingFlatSize(input1_shape, input2_shape, output_shape);
MatchingElementsSize(input1_shape, input2_shape, output_shape);
const int32 output_activation_min = params.quantized_activation_min;
const int32 output_activation_max = params.quantized_activation_max;
for (int i = 0; i < flat_size; ++i) {
@ -2139,8 +2141,9 @@ inline void MulNoActivation(const ArithmeticParams& params,
auto scalar = input1_data[0];
output_map.array() = scalar * input2_map.array();
} else {
// Should not come here.
TFLITE_DCHECK(false);
reference_ops::BroadcastMul4DSlow(params, input1_shape, input1_data,
input2_shape, input2_data, output_shape,
output_data);
}
}
@ -2153,7 +2156,7 @@ inline void Mul(const ArithmeticParams& params,
// properly optimized version.
const int flat_size =
MatchingFlatSize(input1_shape, input2_shape, output_shape);
MatchingElementsSize(input1_shape, input2_shape, output_shape);
for (int i = 0; i < flat_size; i++) {
// F0 uses 0 integer bits, range [-1, 1].
@ -2178,7 +2181,7 @@ inline void Mul(const ArithmeticParams& params,
TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
const int flat_size =
MatchingFlatSize(input1_shape, input2_shape, output_shape);
MatchingElementsSize(input1_shape, input2_shape, output_shape);
for (int i = 0; i < flat_size; i++) {
// F0 uses 0 integer bits, range [-1, 1].
@ -2342,6 +2345,8 @@ inline void MulSimpleBroadcast(int size, const ArithmeticParams& params,
}
// Broadcast mul that can often be used for inner loop of broadcast Mul.
// This function will handle scalar_value (LHS) * vector_values (RHS).
// Since this is a float function, the input params do not matter here.
inline void MulSimpleBroadcast(int size, const ArithmeticParams& params,
const float broadcast_value,
const float* input2_data, float* output_data) {
@ -2380,7 +2385,7 @@ inline void Mul(const ArithmeticParams& params,
params.quantized_activation_max);
gemmlowp::ScopedProfilingLabel label("Mul/8bit");
const int flat_size =
MatchingFlatSize(input1_shape, input2_shape, output_shape);
MatchingElementsSize(input1_shape, input2_shape, output_shape);
MulElementwise(flat_size, params, input1_data, input2_data, output_data);
}
@ -2509,6 +2514,8 @@ inline void BroadcastMulFivefold(const ArithmeticParams& params,
for (int i1 = 0; i1 < y1; ++i1) {
input2_data_ptr = input2_data_reset;
for (int i2 = 0; i2 < y2; ++i2) {
// The inputs may be switched here, but the common parameters do not
// matter, as they will not influence the float math execution.
MulSimpleBroadcast(y3, params, *input1_data_ptr, input2_data_ptr,
output_data_ptr);
input2_data_ptr += y3;
@ -2652,7 +2659,7 @@ inline void SubNonBroadcast(const ArithmeticParams& params,
float* output_data) {
gemmlowp::ScopedProfilingLabel label("SubNonBroadcast");
const int flat_size =
MatchingFlatSize(input1_shape, input2_shape, output_shape);
MatchingElementsSize(input1_shape, input2_shape, output_shape);
for (int i = 0; i < flat_size; ++i) {
output_data[i] = ActivationFunctionWithMinMax(
input1_data[i] - input2_data[i], params.float_activation_min,
@ -2669,7 +2676,7 @@ inline void SubWithActivation(const ArithmeticParams& params,
int32* output_data) {
gemmlowp::ScopedProfilingLabel label("SubWithActivation/int32");
const int flat_size =
MatchingFlatSize(input1_shape, input2_shape, input2_shape);
MatchingElementsSize(input1_shape, input2_shape, output_shape);
for (int i = 0; i < flat_size; ++i) {
output_data[i] = ActivationFunctionWithMinMax(
input1_data[i] - input2_data[i], params.quantized_activation_min,
@ -2686,7 +2693,7 @@ inline void SubWithActivation(const ArithmeticParams& params,
float* output_data) {
gemmlowp::ScopedProfilingLabel label("SubWithActivation/float");
const int flat_size =
MatchingFlatSize(input1_shape, input2_shape, input2_shape);
MatchingElementsSize(input1_shape, input2_shape, output_shape);
for (int i = 0; i < flat_size; ++i) {
output_data[i] = ActivationFunctionWithMinMax(
input1_data[i] - input2_data[i], params.float_activation_min,

View File

@ -28,7 +28,7 @@ inline void Add(const ArithmeticParams& params,
const RuntimeShape& input2_shape, const T* input2_data,
const RuntimeShape& output_shape, T* output_data) {
const int flat_size =
MatchingFlatSize(input1_shape, input2_shape, output_shape);
MatchingElementsSize(input1_shape, input2_shape, output_shape);
for (int i = 0; i < flat_size; ++i) {
output_data[i] = ActivationFunctionWithMinMax(
input1_data[i] + input2_data[i], params.quantized_activation_min,
@ -40,8 +40,9 @@ inline void Add(const ArithmeticParams& params,
const RuntimeShape& input1_shape, const float* input1_data,
const RuntimeShape& input2_shape, const float* input2_data,
const RuntimeShape& output_shape, float* output_data) {
const int size = MatchingFlatSize(input1_shape, input2_shape, output_shape);
for (int i = 0; i < size; i++) {
const int flat_size =
MatchingElementsSize(input1_shape, input2_shape, output_shape);
for (int i = 0; i < flat_size; i++) {
auto x = input1_data[i] + input2_data[i];
output_data[i] = ActivationFunctionWithMinMax(
x, params.float_activation_min, params.float_activation_max);
@ -122,7 +123,7 @@ inline void Add(const ArithmeticParams& params,
TFLITE_DCHECK_LE(params.quantized_activation_min,
params.quantized_activation_max);
const int flat_size =
MatchingFlatSize(input1_shape, input2_shape, output_shape);
MatchingElementsSize(input1_shape, input2_shape, output_shape);
TFLITE_DCHECK_GT(params.input1_offset, -256);
TFLITE_DCHECK_GT(params.input2_offset, -256);
@ -140,7 +141,7 @@ inline void Add(const ArithmeticParams& params,
const int input1_shift = params.input1_shift;
const int flat_size =
MatchingFlatSize(output_shape, input1_shape, input2_shape);
MatchingElementsSize(input1_shape, input2_shape, output_shape);
const int16 output_activation_min = params.quantized_activation_min;
const int16 output_activation_max = params.quantized_activation_max;

View File

@ -64,7 +64,7 @@ inline void Add(const ArithmeticParams& params,
TFLITE_DCHECK_LE(params.quantized_activation_min,
params.quantized_activation_max);
const int flat_size =
MatchingFlatSize(input1_shape, input2_shape, output_shape);
MatchingElementsSize(input1_shape, input2_shape, output_shape);
const int32_t int8_max_value = std::numeric_limits<int8_t>::max();
TFLITE_DCHECK_GE(params.input1_offset, -1 * int8_max_value);

View File

@ -48,7 +48,7 @@ inline void Mul(const ArithmeticParams& params,
params.quantized_activation_max);
gemmlowp::ScopedProfilingLabel label("Mul/8bit");
const int flat_size =
MatchingFlatSize(input1_shape, input2_shape, output_shape);
MatchingElementsSize(input1_shape, input2_shape, output_shape);
MulElementwise(flat_size, params, input1_data, input2_data, output_data);
}
@ -65,7 +65,7 @@ inline void Mul(const ArithmeticParams& params,
TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
const int flat_size =
MatchingFlatSize(input1_shape, input2_shape, output_shape);
MatchingElementsSize(input1_shape, input2_shape, output_shape);
for (int i = 0; i < flat_size; i++) {
// F0 uses 0 integer bits, range [-1, 1].

View File

@ -350,7 +350,7 @@ inline void Mul(const ArithmeticParams& params,
GetActivationParams(params, &output_activation_min, &output_activation_max);
const int flat_size =
MatchingFlatSize(input1_shape, input2_shape, output_shape);
MatchingElementsSize(input1_shape, input2_shape, output_shape);
for (int i = 0; i < flat_size; ++i) {
output_data[i] = ActivationFunctionWithMinMax(
input1_data[i] * input2_data[i], output_activation_min,
@ -444,7 +444,7 @@ inline void Mul(const ArithmeticParams& params,
params.quantized_activation_max);
gemmlowp::ScopedProfilingLabel label("Mul/8bit");
const int flat_size =
MatchingFlatSize(input1_shape, input2_shape, output_shape);
MatchingElementsSize(input1_shape, input2_shape, output_shape);
MulElementwise(flat_size, params, input1_data, input2_data, output_data);
}
@ -551,7 +551,7 @@ inline void Mul(const ArithmeticParams& params,
gemmlowp::ScopedProfilingLabel label("Mul/Int16");
const int flat_size =
MatchingFlatSize(input1_shape, input2_shape, output_shape);
MatchingElementsSize(input1_shape, input2_shape, output_shape);
for (int i = 0; i < flat_size; i++) {
// F0 uses 0 integer bits, range [-1, 1].
@ -574,7 +574,7 @@ inline void Mul(const ArithmeticParams& params,
TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
const int flat_size =
MatchingFlatSize(input1_shape, input2_shape, output_shape);
MatchingElementsSize(input1_shape, input2_shape, output_shape);
for (int i = 0; i < flat_size; i++) {
// F0 uses 0 integer bits, range [-1, 1].
@ -655,7 +655,7 @@ inline void Div(const ArithmeticParams& params,
GetActivationParams(params, &output_activation_min, &output_activation_max);
const int flat_size =
MatchingFlatSize(input1_shape, input2_shape, output_shape);
MatchingElementsSize(input1_shape, input2_shape, output_shape);
for (int i = 0; i < flat_size; ++i) {
output_data[i] = ActivationFunctionWithMinMax(
input1_data[i] / input2_data[i], output_activation_min,
@ -706,7 +706,7 @@ inline void Div(const ArithmeticParams& params,
params.quantized_activation_max);
gemmlowp::ScopedProfilingLabel label("Div/8bit");
const int flat_size =
MatchingFlatSize(input1_shape, input2_shape, output_shape);
MatchingElementsSize(input1_shape, input2_shape, output_shape);
DivElementwise(flat_size, params, input1_data, input2_data, output_data);
}
@ -779,7 +779,7 @@ inline void SubNonBroadcast(const ArithmeticParams& params,
const RuntimeShape& output_shape,
float* output_data) {
const int flat_size =
MatchingFlatSize(input1_shape, input2_shape, output_shape);
MatchingElementsSize(input1_shape, input2_shape, output_shape);
for (int i = 0; i < flat_size; ++i) {
output_data[i] = ActivationFunctionWithMinMax(
input1_data[i] - input2_data[i], params.float_activation_min,
@ -795,7 +795,7 @@ inline void SubNonBroadcast(const ArithmeticParams& params,
const RuntimeShape& output_shape,
int32* output_data) {
const int flat_size =
MatchingFlatSize(input1_shape, input2_shape, output_shape);
MatchingElementsSize(input1_shape, input2_shape, output_shape);
for (int i = 0; i < flat_size; ++i) {
output_data[i] = ActivationFunctionWithMinMax(
input1_data[i] - input2_data[i], params.quantized_activation_min,
@ -1043,7 +1043,7 @@ inline void SubWithActivation(const ArithmeticParams& params,
int32* output_data) {
gemmlowp::ScopedProfilingLabel label("SubWithActivation");
const int flat_size =
MatchingFlatSize(input1_shape, input2_shape, input2_shape);
MatchingElementsSize(input1_shape, input2_shape, output_shape);
for (int i = 0; i < flat_size; ++i) {
output_data[i] = ActivationFunctionWithMinMax(
input1_data[i] - input2_data[i], params.quantized_activation_min,
@ -1059,7 +1059,7 @@ inline void SubWithActivation(const ArithmeticParams& params,
const RuntimeShape& output_shape,
float* output_data) {
const int flat_size =
MatchingFlatSize(input1_shape, input2_shape, input2_shape);
MatchingElementsSize(input1_shape, input2_shape, output_shape);
for (int i = 0; i < flat_size; ++i) {
output_data[i] = ActivationFunctionWithMinMax(
input1_data[i] - input2_data[i], params.float_activation_min,
@ -1074,7 +1074,7 @@ inline void Sub16(const ArithmeticParams& params,
gemmlowp::ScopedProfilingLabel label("Sub/Int16");
const int input1_shift = params.input1_shift;
const int flat_size =
MatchingFlatSize(output_shape, input1_shape, input2_shape);
MatchingElementsSize(input1_shape, input2_shape, output_shape);
const int16 output_activation_min = params.quantized_activation_min;
const int16 output_activation_max = params.quantized_activation_max;

View File

@ -457,6 +457,25 @@ inline int RequiredBufferSizeForDims(const Dims<4>& dims) {
return FlatSize(dims);
}
// Returns the flat element count of `shape` after checking that
// `check_shape_0` holds exactly the same total number of elements. Only the
// element counts are compared — individual dimensions are not.
inline int MatchingElementsSize(const RuntimeShape& shape,
                                const RuntimeShape& check_shape_0) {
  const int flat_size = shape.FlatSize();
  TFLITE_CHECK_EQ(flat_size, check_shape_0.FlatSize());
  return flat_size;
}
// Returns the flat element count of `shape`, checking that `check_shape_0`
// and `check_shape_1` each hold the same total number of elements. Only the
// element counts are compared — individual dimensions are not.
inline int MatchingElementsSize(const RuntimeShape& shape,
                                const RuntimeShape& check_shape_0,
                                const RuntimeShape& check_shape_1) {
  const int flat_size = shape.FlatSize();
  TFLITE_CHECK_EQ(flat_size, check_shape_0.FlatSize());
  TFLITE_CHECK_EQ(check_shape_0.FlatSize(), check_shape_1.FlatSize());
  return flat_size;
}
// Flat size calculation, checking that dimensions match with one or more other
// arrays.
inline int MatchingFlatSize(const RuntimeShape& shape,

View File

@ -110,8 +110,6 @@ void EvalMul(TfLiteContext* context, TfLiteNode* node, TfLiteMulParams* params,
const OpData* data, const TfLiteTensor* input1,
const TfLiteTensor* input2, TfLiteTensor* output) {
tflite::ArithmeticParams op_params;
// requires_flat_size_broadcast is used for BroadcastMul4DSlow.
const bool requires_flat_size_broadcast = !HaveSameShapes(input1, input2);
const bool need_broadcast = optimized_ops::ProcessBroadcastShapes(
GetTensorShape(input1), GetTensorShape(input2), &op_params);
#define TF_LITE_MUL(type, opname, data_type) \
@ -127,13 +125,13 @@ void EvalMul(TfLiteContext* context, TfLiteNode* node, TfLiteMulParams* params,
if (output->type == kTfLiteInt32) {
if (kernel_type == kReference) {
if (requires_flat_size_broadcast) {
if (need_broadcast) {
TF_LITE_MUL(reference_ops, BroadcastMul4DSlow, int32_t);
} else {
TF_LITE_MUL(reference_ops, Mul, int32_t);
}
} else {
if (requires_flat_size_broadcast) {
if (need_broadcast) {
TF_LITE_MUL(optimized_ops, BroadcastMul4DSlow, int32_t);
} else {
TF_LITE_MUL(optimized_ops, Mul, int32_t);
@ -141,7 +139,7 @@ void EvalMul(TfLiteContext* context, TfLiteNode* node, TfLiteMulParams* params,
}
} else if (output->type == kTfLiteFloat32) {
if (kernel_type == kReference) {
if (requires_flat_size_broadcast) {
if (need_broadcast) {
TF_LITE_MUL(reference_ops, BroadcastMul4DSlow, float);
} else {
TF_LITE_MUL(reference_ops, Mul, float);
@ -149,8 +147,6 @@ void EvalMul(TfLiteContext* context, TfLiteNode* node, TfLiteMulParams* params,
} else {
if (need_broadcast) {
TF_LITE_MUL(optimized_ops, BroadcastMulFivefold, float);
} else if (requires_flat_size_broadcast) {
TF_LITE_MUL(optimized_ops, BroadcastMul4DSlow, float);
} else {
TF_LITE_MUL(optimized_ops, Mul, float);
}