From 9a662b14eaa85ce466be82abd758a636c1eaa756 Mon Sep 17 00:00:00 2001 From: Nat Jeffries Date: Mon, 21 Sep 2020 15:58:41 -0700 Subject: [PATCH] Add support for non-4D input tensors in MEAN operator. Fixes:#43332 PiperOrigin-RevId: 332954480 Change-Id: Ib2291459da50d5458bb34a139a2a10db6a98d0df --- tensorflow/lite/micro/kernels/reduce.cc | 27 +-- tensorflow/lite/micro/kernels/reduce_test.cc | 241 +++++++++++++++---- 2 files changed, 204 insertions(+), 64 deletions(-) diff --git a/tensorflow/lite/micro/kernels/reduce.cc b/tensorflow/lite/micro/kernels/reduce.cc index 0ee2b83923a..8c60269cb02 100644 --- a/tensorflow/lite/micro/kernels/reduce.cc +++ b/tensorflow/lite/micro/kernels/reduce.cc @@ -150,21 +150,17 @@ TfLiteStatus EvalMean(TfLiteContext* context, TfLiteNode* node) { tflite::MeanParams op_params; ResolveAxis(tflite::micro::GetTensorData(axis), num_axis, &op_params); - // TODO(b/146571391): Support only 4D Input and 2D Axis for Mean until - // scratch tensor allocation has been implemented in (b/132070898) - bool is_valid_inputs = (input->dims->size == 4 && op_params.axis_count == 2 && - ((op_params.axis[0] == 1 && op_params.axis[1] == 2) || - (op_params.axis[0] == 2 && op_params.axis[1] == 1))); - TF_LITE_ENSURE_MSG( - context, is_valid_inputs == true, - "Number of Input " - "dimensions != 4 OR the Axis is not either [1, 2] or [2, 1]"); + + // Special case mean implementation exists for 4D mean across axes 1 and 2. + bool special_case_4d_axes_1_and_2 = + input->dims->size == 4 && op_params.axis_count == 2 && + ((op_params.axis[0] == 1 && op_params.axis[1] == 2) || + (op_params.axis[0] == 2 && op_params.axis[1] == 1)); + switch (input->type) { case kTfLiteFloat32: { - // TODO(b/139102329): Handle the below special case in the combined - // reference method. // Defer to specialized implementation for 4D Mean across axes 1 & 2. - if (params->keep_dims) { + if (params->keep_dims && special_case_4d_axes_1_and_2) { reference_ops::Mean(op_params, tflite::micro::GetTensorShape(input), tflite::micro::GetTensorData(input), tflite::micro::GetTensorShape(output), @@ -182,7 +178,8 @@ TfLiteStatus EvalMean(TfLiteContext* context, TfLiteNode* node) { } } break; case kTfLiteInt8: { - if (params->keep_dims) { + // Defer to specialized implementation for 4D Mean across axes 1 & 2. + if (params->keep_dims && special_case_4d_axes_1_and_2) { reference_integer_ops::Mean( op_params, op_data->multiplier, op_data->shift, tflite::micro::GetTensorShape(input), @@ -217,7 +214,8 @@ TfLiteStatus EvalMean(TfLiteContext* context, TfLiteNode* node) { } } break; case kTfLiteUInt8: { - if (params->keep_dims) { + // Defer to specialized implementation for 4D Mean across axes 1 & 2. + if (params->keep_dims && special_case_4d_axes_1_and_2) { reference_ops::Mean(op_params, tflite::micro::GetTensorShape(input), tflite::micro::GetTensorData(input), op_data->input_zp, op_data->input_scale, @@ -253,7 +251,6 @@ TfLiteStatus EvalMean(TfLiteContext* context, TfLiteNode* node) { } } break; default: - // TODO(b/144955155): Support uint8_t(b/144955155) and int8_t(b/144955018) TF_LITE_ENSURE_MSG(context, false, "Currently, only float32, int8 or uint8 input type " "is supported."); diff --git a/tensorflow/lite/micro/kernels/reduce_test.cc b/tensorflow/lite/micro/kernels/reduce_test.cc index 26a73b6c8ce..e7be6569ca7 100644 --- a/tensorflow/lite/micro/kernels/reduce_test.cc +++ b/tensorflow/lite/micro/kernels/reduce_test.cc @@ -24,20 +24,45 @@ namespace tflite { namespace testing { namespace { -// Common inputs and outputs. +// Common 2D inputs, outputs and axis. +static const int kInputElements2D = 8; +static const int kInputShape2D[] = {2, 2, 4}; +static const float kInputData2D[] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0}; + +static const int kAxisShape2D[] = {1, 1}; +static const int32_t kAxisData2D[] = {1}; + +static const int kOutputElements2D = 2; +static const int kOutputShape2D[] = {2, 1, 2}; +static const float kGoldenData2D[] = {2.5, 6.5}; + +// Common 3D inputs, outputs and axis. +static const int kInputElements3D = 8; +static const int kInputShape3D[] = {3, 2, 2, 2}; +static const float kInputData3D[] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0}; + +static const int kAxisShape3D[] = {1, 2}; +static const int32_t kAxisData3D[] = {1, 2}; + +static const int kOutputElements3D = 2; +static const int kOutputShape3D[] = {2, 1, 2}; +static const float kGoldenData3D[] = {2.5, 6.5}; + +// Common 4D inputs, outputs and axis. static const int kInputElements4D = 24; static const int kInputShape4D[] = {4, 2, 2, 3, 2}; static const float kInputData4D[] = { 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0}; -// static const int kAxisElements = 3; -static const int kAxisShape[] = {1, 2}; -static const int32_t kAxisData[] = {1, 2}; +static const int kAxisShape4D[] = {1, 2}; +static const int32_t kAxisData4D[] = {1, 2}; -static const int kOutputElements = 4; -static const int kOutputShape[] = {4, 2, 1, 1, 2}; -static const float kGoldenData[] = {6, 7, 18, 19}; +static const int kOutputElements4D = 4; +static const int kOutputShape4D[] = {4, 2, 1, 1, 2}; +static const float kGoldenData4D[] = {6, 7, 18, 19}; + +// Axis shape and contents are independent of input / output dimensions. template TfLiteStatus ValidateReduceGoldens(TfLiteTensor* tensors, int tensors_size, @@ -202,8 +227,124 @@ void TestMeanOpQuantized(const int* input_dims_data, const float* input_data, TF_LITE_MICRO_TESTS_BEGIN +TF_LITE_MICRO_TEST(MeanFloat2DKeepDims) { + float output_data[tflite::testing::kOutputElements2D]; + + TfLiteReducerParams params = {true}; + + tflite::testing::TestMeanFloatInput4D( + tflite::testing::kInputShape2D, tflite::testing::kInputData2D, + tflite::testing::kAxisShape2D, tflite::testing::kAxisData2D, + tflite::testing::kOutputShape2D, tflite::testing::kGoldenData2D, + output_data, ¶ms); +} + +TF_LITE_MICRO_TEST(MeanInt82DKeepDims) { + int8_t expected_output_data_quant[tflite::testing::kOutputElements2D]; + int8_t output_data_quant[tflite::testing::kOutputElements2D]; + int8_t input_data_quant[tflite::testing::kInputElements2D]; + + float input_scale = 0.5f; + int input_zero_point = 0; + float output_scale = 0.5f; + int output_zero_point = 0; + + TfLiteReducerParams params = { + true // keep_dims + }; + + tflite::testing::TestMeanOpQuantized( + tflite::testing::kInputShape2D, tflite::testing::kInputData2D, + input_data_quant, input_scale, input_zero_point, + tflite::testing::kAxisShape2D, tflite::testing::kAxisData2D, + tflite::testing::kOutputShape2D, tflite::testing::kGoldenData2D, + output_data_quant, expected_output_data_quant, output_scale, + output_zero_point, ¶ms); +} + +TF_LITE_MICRO_TEST(MeanUInt82DKeepDims) { + uint8_t expected_output_data_quant[tflite::testing::kOutputElements2D]; + uint8_t output_data_quant[tflite::testing::kOutputElements2D]; + uint8_t input_data_quant[tflite::testing::kInputElements2D]; + + float input_scale = 0.5f; + int input_zero_point = 128; + float output_scale = 0.5f; + int output_zero_point = 128; + + TfLiteReducerParams params = { + true // keep_dims + }; + + tflite::testing::TestMeanOpQuantized( + tflite::testing::kInputShape2D, tflite::testing::kInputData2D, + input_data_quant, input_scale, input_zero_point, + tflite::testing::kAxisShape2D, tflite::testing::kAxisData2D, + tflite::testing::kOutputShape2D, tflite::testing::kGoldenData2D, + output_data_quant, expected_output_data_quant, output_scale, + output_zero_point, ¶ms); +} + +TF_LITE_MICRO_TEST(MeanFloat3DKeepDims) { + float output_data[tflite::testing::kOutputElements3D]; + + TfLiteReducerParams params = {true}; + + tflite::testing::TestMeanFloatInput4D( + tflite::testing::kInputShape3D, tflite::testing::kInputData3D, + tflite::testing::kAxisShape3D, tflite::testing::kAxisData3D, + tflite::testing::kOutputShape3D, tflite::testing::kGoldenData3D, + output_data, ¶ms); +} + +TF_LITE_MICRO_TEST(MeanInt83DKeepDims) { + int8_t expected_output_data_quant[tflite::testing::kOutputElements3D]; + int8_t output_data_quant[tflite::testing::kOutputElements3D]; + int8_t input_data_quant[tflite::testing::kInputElements3D]; + + float input_scale = 0.5f; + int input_zero_point = 0; + float output_scale = 0.5f; + int output_zero_point = 0; + + TfLiteReducerParams params = { + true // keep_dims + }; + + tflite::testing::TestMeanOpQuantized( + tflite::testing::kInputShape3D, tflite::testing::kInputData3D, + input_data_quant, input_scale, input_zero_point, + tflite::testing::kAxisShape3D, tflite::testing::kAxisData3D, + tflite::testing::kOutputShape3D, tflite::testing::kGoldenData3D, + output_data_quant, expected_output_data_quant, output_scale, + output_zero_point, ¶ms); +} + +TF_LITE_MICRO_TEST(MeanUInt83DKeepDims) { + uint8_t expected_output_data_quant[tflite::testing::kOutputElements3D]; + uint8_t output_data_quant[tflite::testing::kOutputElements3D]; + uint8_t input_data_quant[tflite::testing::kInputElements3D]; + + float input_scale = 0.5f; + int input_zero_point = 138; + float output_scale = 0.5f; + int output_zero_point = 138; + + TfLiteReducerParams params = { + true // keep_dims + }; + + tflite::testing::TestMeanOpQuantized( + tflite::testing::kInputShape3D, tflite::testing::kInputData3D, + input_data_quant, input_scale, input_zero_point, + tflite::testing::kAxisShape3D, tflite::testing::kAxisData3D, + tflite::testing::kOutputShape3D, tflite::testing::kGoldenData3D, + output_data_quant, expected_output_data_quant, output_scale, + output_zero_point, ¶ms); +} + TF_LITE_MICRO_TEST(MeanFloat4DKeepDims) { - float output_data[tflite::testing::kOutputElements]; + float output_data[tflite::testing::kOutputElements4D]; TfLiteReducerParams params = { true // keep_dims @@ -211,14 +352,14 @@ TF_LITE_MICRO_TEST(MeanFloat4DKeepDims) { tflite::testing::TestMeanFloatInput4D( tflite::testing::kInputShape4D, tflite::testing::kInputData4D, - tflite::testing::kAxisShape, tflite::testing::kAxisData, - tflite::testing::kOutputShape, tflite::testing::kGoldenData, output_data, - ¶ms); + tflite::testing::kAxisShape4D, tflite::testing::kAxisData4D, + tflite::testing::kOutputShape4D, tflite::testing::kGoldenData4D, + output_data, ¶ms); } TF_LITE_MICRO_TEST(MeanInt84DKeepDims) { - int8_t expected_output_data_quant[tflite::testing::kOutputElements]; - int8_t output_data_quant[tflite::testing::kOutputElements]; + int8_t expected_output_data_quant[tflite::testing::kOutputElements4D]; + int8_t output_data_quant[tflite::testing::kOutputElements4D]; int8_t input_data_quant[tflite::testing::kInputElements4D]; float input_scale = 0.5f; @@ -233,15 +374,15 @@ TF_LITE_MICRO_TEST(MeanInt84DKeepDims) { tflite::testing::TestMeanOpQuantized( tflite::testing::kInputShape4D, tflite::testing::kInputData4D, input_data_quant, input_scale, input_zero_point, - tflite::testing::kAxisShape, tflite::testing::kAxisData, - tflite::testing::kOutputShape, tflite::testing::kGoldenData, + tflite::testing::kAxisShape4D, tflite::testing::kAxisData4D, + tflite::testing::kOutputShape4D, tflite::testing::kGoldenData4D, output_data_quant, expected_output_data_quant, output_scale, output_zero_point, ¶ms); } TF_LITE_MICRO_TEST(MeanUInt84DKeepDims) { - uint8_t expected_output_data_quant[tflite::testing::kOutputElements]; - uint8_t output_data_quant[tflite::testing::kOutputElements]; + uint8_t expected_output_data_quant[tflite::testing::kOutputElements4D]; + uint8_t output_data_quant[tflite::testing::kOutputElements4D]; uint8_t input_data_quant[tflite::testing::kInputElements4D]; float input_scale = 0.5f; @@ -256,31 +397,31 @@ TF_LITE_MICRO_TEST(MeanUInt84DKeepDims) { tflite::testing::TestMeanOpQuantized( tflite::testing::kInputShape4D, tflite::testing::kInputData4D, input_data_quant, input_scale, input_zero_point, - tflite::testing::kAxisShape, tflite::testing::kAxisData, - tflite::testing::kOutputShape, tflite::testing::kGoldenData, + tflite::testing::kAxisShape4D, tflite::testing::kAxisData4D, + tflite::testing::kOutputShape4D, tflite::testing::kGoldenData4D, output_data_quant, expected_output_data_quant, output_scale, output_zero_point, ¶ms); } TF_LITE_MICRO_TEST(MeanFloat4DWithoutKeepDims) { - const int kOutputShape[] = {2, 2, 2}; - float output_data[tflite::testing::kOutputElements]; + const int kOutputShape4D[] = {2, 2, 2}; + float output_data[tflite::testing::kOutputElements4D]; TfLiteReducerParams params = { false // keep_dims }; tflite::testing::TestMeanFloatInput4D( tflite::testing::kInputShape4D, tflite::testing::kInputData4D, - tflite::testing::kAxisShape, tflite::testing::kAxisData, kOutputShape, - tflite::testing::kGoldenData, output_data, ¶ms); + tflite::testing::kAxisShape4D, tflite::testing::kAxisData4D, + kOutputShape4D, tflite::testing::kGoldenData4D, output_data, ¶ms); } TF_LITE_MICRO_TEST(MeanInt84DWithoutKeepDims) { - int8_t expected_output_data_quant[tflite::testing::kOutputElements]; - int8_t output_data_quant[tflite::testing::kOutputElements]; + int8_t expected_output_data_quant[tflite::testing::kOutputElements4D]; + int8_t output_data_quant[tflite::testing::kOutputElements4D]; int8_t input_data_quant[tflite::testing::kInputElements4D]; - const int kOutputShape[] = {2, 2, 2}; + const int kOutputShape4D[] = {2, 2, 2}; TfLiteReducerParams params = { false // keep_dims }; @@ -292,17 +433,17 @@ TF_LITE_MICRO_TEST(MeanInt84DWithoutKeepDims) { tflite::testing::TestMeanOpQuantized( tflite::testing::kInputShape4D, tflite::testing::kInputData4D, input_data_quant, input_scale, input_zero_point, - tflite::testing::kAxisShape, tflite::testing::kAxisData, kOutputShape, - tflite::testing::kGoldenData, output_data_quant, + tflite::testing::kAxisShape4D, tflite::testing::kAxisData4D, + kOutputShape4D, tflite::testing::kGoldenData4D, output_data_quant, expected_output_data_quant, output_scale, output_zero_point, ¶ms); } TF_LITE_MICRO_TEST(MeanUInt84DWithoutKeepDims) { - uint8_t expected_output_data_quant[tflite::testing::kOutputElements]; - uint8_t output_data_quant[tflite::testing::kOutputElements]; + uint8_t expected_output_data_quant[tflite::testing::kOutputElements4D]; + uint8_t output_data_quant[tflite::testing::kOutputElements4D]; uint8_t input_data_quant[tflite::testing::kInputElements4D]; - const int kOutputShape[] = {2, 2, 2}; + const int kOutputShape4D[] = {2, 2, 2}; TfLiteReducerParams params = { false // keep_dims }; @@ -314,8 +455,8 @@ TF_LITE_MICRO_TEST(MeanUInt84DWithoutKeepDims) { tflite::testing::TestMeanOpQuantized( tflite::testing::kInputShape4D, tflite::testing::kInputData4D, input_data_quant, input_scale, input_zero_point, - tflite::testing::kAxisShape, tflite::testing::kAxisData, kOutputShape, - tflite::testing::kGoldenData, output_data_quant, + tflite::testing::kAxisShape4D, tflite::testing::kAxisData4D, + kOutputShape4D, tflite::testing::kGoldenData4D, output_data_quant, expected_output_data_quant, output_scale, output_zero_point, ¶ms); } @@ -323,17 +464,17 @@ TF_LITE_MICRO_TEST(MeanFloat4DWithoutKeepDimsWithPrecision) { const int kInputShape4D[] = {4, 2, 2, 3, 1}; const float kInputData4D[] = {1.0, 24.0, 13.0, 3.0, 9.0, 17.0, 11.0, 36.0, 14.0, 19.0, 17.0, 22.0}; - const int kOutputElements = 2; - const int kOutputShape[] = {2, 2, 1}; - const float kGoldenData[] = {11.166667, 19.833334}; - float output_data[kOutputElements]; + const int kOutputElements4D = 2; + const int kOutputShape4D[] = {2, 2, 1}; + const float kGoldenData4D[] = {11.166667, 19.833334}; + float output_data[kOutputElements4D]; TfLiteReducerParams params = { false // keep_dims }; tflite::testing::TestMeanFloatInput4D( - kInputShape4D, kInputData4D, tflite::testing::kAxisShape, - tflite::testing::kAxisData, kOutputShape, kGoldenData, output_data, + kInputShape4D, kInputData4D, tflite::testing::kAxisShape4D, + tflite::testing::kAxisData4D, kOutputShape4D, kGoldenData4D, output_data, ¶ms); } @@ -427,8 +568,8 @@ TF_LITE_MICRO_TEST(MeanInt84DWithoutKeepDimsWithPrecision) { const int kInputShape4D[] = {4, 2, 2, 3, 1}; const float kInputData4D[] = {1.0, 24.0, 13.0, 3.0, 9.0, 17.0, 11.0, 36.0, 14.0, 19.0, 17.0, 22.0}; - const int kOutputShape[] = {2, 2, 1}; - const float kGoldenData[] = {11.166667, 19.833334}; + const int kOutputShape4D[] = {2, 2, 1}; + const float kGoldenData4D[] = {11.166667, 19.833334}; TfLiteReducerParams params = { false // keep_dims }; @@ -443,17 +584,18 @@ TF_LITE_MICRO_TEST(MeanInt84DWithoutKeepDimsWithPrecision) { tflite::testing::TestMeanOpQuantized( kInputShape4D, kInputData4D, input_data_quant, input_scale, - input_zero_point, tflite::testing::kAxisShape, tflite::testing::kAxisData, - kOutputShape, kGoldenData, output_data_quant, expected_output_data_quant, - output_scale, output_zero_point, ¶ms); + input_zero_point, tflite::testing::kAxisShape4D, + tflite::testing::kAxisData4D, kOutputShape4D, kGoldenData4D, + output_data_quant, expected_output_data_quant, output_scale, + output_zero_point, ¶ms); } TF_LITE_MICRO_TEST(MeanUInt84DWithoutKeepDimsWithPrecision) { const int kInputShape4D[] = {4, 2, 2, 3, 1}; const float kInputData4D[] = {1.0, 24.0, 13.0, 3.0, 9.0, 17.0, 11.0, 36.0, 14.0, 19.0, 17.0, 22.0}; - const int kOutputShape[] = {2, 2, 1}; - const float kGoldenData[] = {11.166667, 19.833334}; + const int kOutputShape4D[] = {2, 2, 1}; + const float kGoldenData4D[] = {11.166667, 19.833334}; TfLiteReducerParams params = { false // keep_dims }; @@ -469,8 +611,9 @@ TF_LITE_MICRO_TEST(MeanUInt84DWithoutKeepDimsWithPrecision) { tflite::testing::TestMeanOpQuantized( kInputShape4D, kInputData4D, input_data_quant, input_scale, - input_zero_point, tflite::testing::kAxisShape, tflite::testing::kAxisData, - kOutputShape, kGoldenData, output_data_quant, expected_output_data_quant, - output_scale, output_zero_point, ¶ms); + input_zero_point, tflite::testing::kAxisShape4D, + tflite::testing::kAxisData4D, kOutputShape4D, kGoldenData4D, + output_data_quant, expected_output_data_quant, output_scale, + output_zero_point, ¶ms); } TF_LITE_MICRO_TESTS_END