Add support for non-4D input tensors in MEAN operator.

Fixes:#43332

PiperOrigin-RevId: 332954480
Change-Id: Ib2291459da50d5458bb34a139a2a10db6a98d0df
This commit is contained in:
Nat Jeffries 2020-09-21 15:58:41 -07:00 committed by TensorFlower Gardener
parent af51ccb3ae
commit 9a662b14ea
2 changed files with 204 additions and 64 deletions

View File

@ -150,21 +150,17 @@ TfLiteStatus EvalMean(TfLiteContext* context, TfLiteNode* node) {
tflite::MeanParams op_params;
ResolveAxis(tflite::micro::GetTensorData<int>(axis), num_axis, &op_params);
// TODO(b/146571391): Support only 4D Input and 2D Axis for Mean until
// scratch tensor allocation has been implemented in (b/132070898)
bool is_valid_inputs = (input->dims->size == 4 && op_params.axis_count == 2 &&
((op_params.axis[0] == 1 && op_params.axis[1] == 2) ||
(op_params.axis[0] == 2 && op_params.axis[1] == 1)));
TF_LITE_ENSURE_MSG(
context, is_valid_inputs == true,
"Number of Input "
"dimensions != 4 OR the Axis is not either [1, 2] or [2, 1]");
// Special case mean implementation exists for 4D mean across axes 1 and 2.
bool special_case_4d_axes_1_and_2 =
input->dims->size == 4 && op_params.axis_count == 2 &&
((op_params.axis[0] == 1 && op_params.axis[1] == 2) ||
(op_params.axis[0] == 2 && op_params.axis[1] == 1));
switch (input->type) {
case kTfLiteFloat32: {
// TODO(b/139102329): Handle the below special case in the combined
// reference method.
// Defer to specialized implementation for 4D Mean across axes 1 & 2.
if (params->keep_dims) {
if (params->keep_dims && special_case_4d_axes_1_and_2) {
reference_ops::Mean(op_params, tflite::micro::GetTensorShape(input),
tflite::micro::GetTensorData<float>(input),
tflite::micro::GetTensorShape(output),
@ -182,7 +178,8 @@ TfLiteStatus EvalMean(TfLiteContext* context, TfLiteNode* node) {
}
} break;
case kTfLiteInt8: {
if (params->keep_dims) {
// Defer to specialized implementation for 4D Mean across axes 1 & 2.
if (params->keep_dims && special_case_4d_axes_1_and_2) {
reference_integer_ops::Mean(
op_params, op_data->multiplier, op_data->shift,
tflite::micro::GetTensorShape(input),
@ -217,7 +214,8 @@ TfLiteStatus EvalMean(TfLiteContext* context, TfLiteNode* node) {
}
} break;
case kTfLiteUInt8: {
if (params->keep_dims) {
// Defer to specialized implementation for 4D Mean across axes 1 & 2.
if (params->keep_dims && special_case_4d_axes_1_and_2) {
reference_ops::Mean(op_params, tflite::micro::GetTensorShape(input),
tflite::micro::GetTensorData<uint8_t>(input),
op_data->input_zp, op_data->input_scale,
@ -253,7 +251,6 @@ TfLiteStatus EvalMean(TfLiteContext* context, TfLiteNode* node) {
}
} break;
default:
// TODO(b/144955155): Support uint8_t(b/144955155) and int8_t(b/144955018)
TF_LITE_ENSURE_MSG(context, false,
"Currently, only float32, int8 or uint8 input type "
"is supported.");

View File

@ -24,20 +24,45 @@ namespace tflite {
namespace testing {
namespace {
// Common inputs and outputs.
// Common 2D inputs, outputs and axis.
static const int kInputElements2D = 8;
static const int kInputShape2D[] = {2, 2, 4};
static const float kInputData2D[] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0};
static const int kAxisShape2D[] = {1, 1};
static const int32_t kAxisData2D[] = {1};
static const int kOutputElements2D = 2;
static const int kOutputShape2D[] = {2, 1, 2};
static const float kGoldenData2D[] = {2.5, 6.5};
// Common 3D inputs, outputs and axis.
static const int kInputElements3D = 8;
static const int kInputShape3D[] = {3, 2, 2, 2};
static const float kInputData3D[] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0};
static const int kAxisShape3D[] = {1, 2};
static const int32_t kAxisData3D[] = {1, 2};
static const int kOutputElements3D = 2;
static const int kOutputShape3D[] = {2, 1, 2};
static const float kGoldenData3D[] = {2.5, 6.5};
// Common 4D inputs, outputs and axis.
static const int kInputElements4D = 24;
static const int kInputShape4D[] = {4, 2, 2, 3, 2};
static const float kInputData4D[] = {
1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0};
// static const int kAxisElements = 3;
static const int kAxisShape[] = {1, 2};
static const int32_t kAxisData[] = {1, 2};
static const int kAxisShape4D[] = {1, 2};
static const int32_t kAxisData4D[] = {1, 2};
static const int kOutputElements = 4;
static const int kOutputShape[] = {4, 2, 1, 1, 2};
static const float kGoldenData[] = {6, 7, 18, 19};
static const int kOutputElements4D = 4;
static const int kOutputShape4D[] = {4, 2, 1, 1, 2};
static const float kGoldenData4D[] = {6, 7, 18, 19};
// Axis shape and contents are independent of input / output dimensions.
template <typename T>
TfLiteStatus ValidateReduceGoldens(TfLiteTensor* tensors, int tensors_size,
@ -202,8 +227,124 @@ void TestMeanOpQuantized(const int* input_dims_data, const float* input_data,
TF_LITE_MICRO_TESTS_BEGIN
TF_LITE_MICRO_TEST(MeanFloat2DKeepDims) {
float output_data[tflite::testing::kOutputElements2D];
TfLiteReducerParams params = {true};
tflite::testing::TestMeanFloatInput4D(
tflite::testing::kInputShape2D, tflite::testing::kInputData2D,
tflite::testing::kAxisShape2D, tflite::testing::kAxisData2D,
tflite::testing::kOutputShape2D, tflite::testing::kGoldenData2D,
output_data, &params);
}
TF_LITE_MICRO_TEST(MeanInt82DKeepDims) {
int8_t expected_output_data_quant[tflite::testing::kOutputElements2D];
int8_t output_data_quant[tflite::testing::kOutputElements2D];
int8_t input_data_quant[tflite::testing::kInputElements2D];
float input_scale = 0.5f;
int input_zero_point = 0;
float output_scale = 0.5f;
int output_zero_point = 0;
TfLiteReducerParams params = {
true // keep_dims
};
tflite::testing::TestMeanOpQuantized<int8_t>(
tflite::testing::kInputShape2D, tflite::testing::kInputData2D,
input_data_quant, input_scale, input_zero_point,
tflite::testing::kAxisShape2D, tflite::testing::kAxisData2D,
tflite::testing::kOutputShape2D, tflite::testing::kGoldenData2D,
output_data_quant, expected_output_data_quant, output_scale,
output_zero_point, &params);
}
TF_LITE_MICRO_TEST(MeanUInt82DKeepDims) {
uint8_t expected_output_data_quant[tflite::testing::kOutputElements2D];
uint8_t output_data_quant[tflite::testing::kOutputElements2D];
uint8_t input_data_quant[tflite::testing::kInputElements2D];
float input_scale = 0.5f;
int input_zero_point = 128;
float output_scale = 0.5f;
int output_zero_point = 128;
TfLiteReducerParams params = {
true // keep_dims
};
tflite::testing::TestMeanOpQuantized<uint8_t>(
tflite::testing::kInputShape2D, tflite::testing::kInputData2D,
input_data_quant, input_scale, input_zero_point,
tflite::testing::kAxisShape2D, tflite::testing::kAxisData2D,
tflite::testing::kOutputShape2D, tflite::testing::kGoldenData2D,
output_data_quant, expected_output_data_quant, output_scale,
output_zero_point, &params);
}
TF_LITE_MICRO_TEST(MeanFloat3DKeepDims) {
float output_data[tflite::testing::kOutputElements3D];
TfLiteReducerParams params = {true};
tflite::testing::TestMeanFloatInput4D(
tflite::testing::kInputShape3D, tflite::testing::kInputData3D,
tflite::testing::kAxisShape3D, tflite::testing::kAxisData3D,
tflite::testing::kOutputShape3D, tflite::testing::kGoldenData3D,
output_data, &params);
}
TF_LITE_MICRO_TEST(MeanInt83DKeepDims) {
int8_t expected_output_data_quant[tflite::testing::kOutputElements3D];
int8_t output_data_quant[tflite::testing::kOutputElements3D];
int8_t input_data_quant[tflite::testing::kInputElements3D];
float input_scale = 0.5f;
int input_zero_point = 0;
float output_scale = 0.5f;
int output_zero_point = 0;
TfLiteReducerParams params = {
true // keep_dims
};
tflite::testing::TestMeanOpQuantized<int8_t>(
tflite::testing::kInputShape3D, tflite::testing::kInputData3D,
input_data_quant, input_scale, input_zero_point,
tflite::testing::kAxisShape3D, tflite::testing::kAxisData3D,
tflite::testing::kOutputShape3D, tflite::testing::kGoldenData3D,
output_data_quant, expected_output_data_quant, output_scale,
output_zero_point, &params);
}
TF_LITE_MICRO_TEST(MeanUInt83DKeepDims) {
uint8_t expected_output_data_quant[tflite::testing::kOutputElements3D];
uint8_t output_data_quant[tflite::testing::kOutputElements3D];
uint8_t input_data_quant[tflite::testing::kInputElements3D];
float input_scale = 0.5f;
int input_zero_point = 138;
float output_scale = 0.5f;
int output_zero_point = 138;
TfLiteReducerParams params = {
true // keep_dims
};
tflite::testing::TestMeanOpQuantized<uint8_t>(
tflite::testing::kInputShape3D, tflite::testing::kInputData3D,
input_data_quant, input_scale, input_zero_point,
tflite::testing::kAxisShape3D, tflite::testing::kAxisData3D,
tflite::testing::kOutputShape3D, tflite::testing::kGoldenData3D,
output_data_quant, expected_output_data_quant, output_scale,
output_zero_point, &params);
}
TF_LITE_MICRO_TEST(MeanFloat4DKeepDims) {
float output_data[tflite::testing::kOutputElements];
float output_data[tflite::testing::kOutputElements4D];
TfLiteReducerParams params = {
true // keep_dims
@ -211,14 +352,14 @@ TF_LITE_MICRO_TEST(MeanFloat4DKeepDims) {
tflite::testing::TestMeanFloatInput4D(
tflite::testing::kInputShape4D, tflite::testing::kInputData4D,
tflite::testing::kAxisShape, tflite::testing::kAxisData,
tflite::testing::kOutputShape, tflite::testing::kGoldenData, output_data,
&params);
tflite::testing::kAxisShape4D, tflite::testing::kAxisData4D,
tflite::testing::kOutputShape4D, tflite::testing::kGoldenData4D,
output_data, &params);
}
TF_LITE_MICRO_TEST(MeanInt84DKeepDims) {
int8_t expected_output_data_quant[tflite::testing::kOutputElements];
int8_t output_data_quant[tflite::testing::kOutputElements];
int8_t expected_output_data_quant[tflite::testing::kOutputElements4D];
int8_t output_data_quant[tflite::testing::kOutputElements4D];
int8_t input_data_quant[tflite::testing::kInputElements4D];
float input_scale = 0.5f;
@ -233,15 +374,15 @@ TF_LITE_MICRO_TEST(MeanInt84DKeepDims) {
tflite::testing::TestMeanOpQuantized<int8_t>(
tflite::testing::kInputShape4D, tflite::testing::kInputData4D,
input_data_quant, input_scale, input_zero_point,
tflite::testing::kAxisShape, tflite::testing::kAxisData,
tflite::testing::kOutputShape, tflite::testing::kGoldenData,
tflite::testing::kAxisShape4D, tflite::testing::kAxisData4D,
tflite::testing::kOutputShape4D, tflite::testing::kGoldenData4D,
output_data_quant, expected_output_data_quant, output_scale,
output_zero_point, &params);
}
TF_LITE_MICRO_TEST(MeanUInt84DKeepDims) {
uint8_t expected_output_data_quant[tflite::testing::kOutputElements];
uint8_t output_data_quant[tflite::testing::kOutputElements];
uint8_t expected_output_data_quant[tflite::testing::kOutputElements4D];
uint8_t output_data_quant[tflite::testing::kOutputElements4D];
uint8_t input_data_quant[tflite::testing::kInputElements4D];
float input_scale = 0.5f;
@ -256,31 +397,31 @@ TF_LITE_MICRO_TEST(MeanUInt84DKeepDims) {
tflite::testing::TestMeanOpQuantized<uint8_t>(
tflite::testing::kInputShape4D, tflite::testing::kInputData4D,
input_data_quant, input_scale, input_zero_point,
tflite::testing::kAxisShape, tflite::testing::kAxisData,
tflite::testing::kOutputShape, tflite::testing::kGoldenData,
tflite::testing::kAxisShape4D, tflite::testing::kAxisData4D,
tflite::testing::kOutputShape4D, tflite::testing::kGoldenData4D,
output_data_quant, expected_output_data_quant, output_scale,
output_zero_point, &params);
}
TF_LITE_MICRO_TEST(MeanFloat4DWithoutKeepDims) {
const int kOutputShape[] = {2, 2, 2};
float output_data[tflite::testing::kOutputElements];
const int kOutputShape4D[] = {2, 2, 2};
float output_data[tflite::testing::kOutputElements4D];
TfLiteReducerParams params = {
false // keep_dims
};
tflite::testing::TestMeanFloatInput4D(
tflite::testing::kInputShape4D, tflite::testing::kInputData4D,
tflite::testing::kAxisShape, tflite::testing::kAxisData, kOutputShape,
tflite::testing::kGoldenData, output_data, &params);
tflite::testing::kAxisShape4D, tflite::testing::kAxisData4D,
kOutputShape4D, tflite::testing::kGoldenData4D, output_data, &params);
}
TF_LITE_MICRO_TEST(MeanInt84DWithoutKeepDims) {
int8_t expected_output_data_quant[tflite::testing::kOutputElements];
int8_t output_data_quant[tflite::testing::kOutputElements];
int8_t expected_output_data_quant[tflite::testing::kOutputElements4D];
int8_t output_data_quant[tflite::testing::kOutputElements4D];
int8_t input_data_quant[tflite::testing::kInputElements4D];
const int kOutputShape[] = {2, 2, 2};
const int kOutputShape4D[] = {2, 2, 2};
TfLiteReducerParams params = {
false // keep_dims
};
@ -292,17 +433,17 @@ TF_LITE_MICRO_TEST(MeanInt84DWithoutKeepDims) {
tflite::testing::TestMeanOpQuantized<int8_t>(
tflite::testing::kInputShape4D, tflite::testing::kInputData4D,
input_data_quant, input_scale, input_zero_point,
tflite::testing::kAxisShape, tflite::testing::kAxisData, kOutputShape,
tflite::testing::kGoldenData, output_data_quant,
tflite::testing::kAxisShape4D, tflite::testing::kAxisData4D,
kOutputShape4D, tflite::testing::kGoldenData4D, output_data_quant,
expected_output_data_quant, output_scale, output_zero_point, &params);
}
TF_LITE_MICRO_TEST(MeanUInt84DWithoutKeepDims) {
uint8_t expected_output_data_quant[tflite::testing::kOutputElements];
uint8_t output_data_quant[tflite::testing::kOutputElements];
uint8_t expected_output_data_quant[tflite::testing::kOutputElements4D];
uint8_t output_data_quant[tflite::testing::kOutputElements4D];
uint8_t input_data_quant[tflite::testing::kInputElements4D];
const int kOutputShape[] = {2, 2, 2};
const int kOutputShape4D[] = {2, 2, 2};
TfLiteReducerParams params = {
false // keep_dims
};
@ -314,8 +455,8 @@ TF_LITE_MICRO_TEST(MeanUInt84DWithoutKeepDims) {
tflite::testing::TestMeanOpQuantized<uint8_t>(
tflite::testing::kInputShape4D, tflite::testing::kInputData4D,
input_data_quant, input_scale, input_zero_point,
tflite::testing::kAxisShape, tflite::testing::kAxisData, kOutputShape,
tflite::testing::kGoldenData, output_data_quant,
tflite::testing::kAxisShape4D, tflite::testing::kAxisData4D,
kOutputShape4D, tflite::testing::kGoldenData4D, output_data_quant,
expected_output_data_quant, output_scale, output_zero_point, &params);
}
@ -323,17 +464,17 @@ TF_LITE_MICRO_TEST(MeanFloat4DWithoutKeepDimsWithPrecision) {
const int kInputShape4D[] = {4, 2, 2, 3, 1};
const float kInputData4D[] = {1.0, 24.0, 13.0, 3.0, 9.0, 17.0,
11.0, 36.0, 14.0, 19.0, 17.0, 22.0};
const int kOutputElements = 2;
const int kOutputShape[] = {2, 2, 1};
const float kGoldenData[] = {11.166667, 19.833334};
float output_data[kOutputElements];
const int kOutputElements4D = 2;
const int kOutputShape4D[] = {2, 2, 1};
const float kGoldenData4D[] = {11.166667, 19.833334};
float output_data[kOutputElements4D];
TfLiteReducerParams params = {
false // keep_dims
};
tflite::testing::TestMeanFloatInput4D(
kInputShape4D, kInputData4D, tflite::testing::kAxisShape,
tflite::testing::kAxisData, kOutputShape, kGoldenData, output_data,
kInputShape4D, kInputData4D, tflite::testing::kAxisShape4D,
tflite::testing::kAxisData4D, kOutputShape4D, kGoldenData4D, output_data,
&params);
}
@ -427,8 +568,8 @@ TF_LITE_MICRO_TEST(MeanInt84DWithoutKeepDimsWithPrecision) {
const int kInputShape4D[] = {4, 2, 2, 3, 1};
const float kInputData4D[] = {1.0, 24.0, 13.0, 3.0, 9.0, 17.0,
11.0, 36.0, 14.0, 19.0, 17.0, 22.0};
const int kOutputShape[] = {2, 2, 1};
const float kGoldenData[] = {11.166667, 19.833334};
const int kOutputShape4D[] = {2, 2, 1};
const float kGoldenData4D[] = {11.166667, 19.833334};
TfLiteReducerParams params = {
false // keep_dims
};
@ -443,17 +584,18 @@ TF_LITE_MICRO_TEST(MeanInt84DWithoutKeepDimsWithPrecision) {
tflite::testing::TestMeanOpQuantized<int8_t>(
kInputShape4D, kInputData4D, input_data_quant, input_scale,
input_zero_point, tflite::testing::kAxisShape, tflite::testing::kAxisData,
kOutputShape, kGoldenData, output_data_quant, expected_output_data_quant,
output_scale, output_zero_point, &params);
input_zero_point, tflite::testing::kAxisShape4D,
tflite::testing::kAxisData4D, kOutputShape4D, kGoldenData4D,
output_data_quant, expected_output_data_quant, output_scale,
output_zero_point, &params);
}
TF_LITE_MICRO_TEST(MeanUInt84DWithoutKeepDimsWithPrecision) {
const int kInputShape4D[] = {4, 2, 2, 3, 1};
const float kInputData4D[] = {1.0, 24.0, 13.0, 3.0, 9.0, 17.0,
11.0, 36.0, 14.0, 19.0, 17.0, 22.0};
const int kOutputShape[] = {2, 2, 1};
const float kGoldenData[] = {11.166667, 19.833334};
const int kOutputShape4D[] = {2, 2, 1};
const float kGoldenData4D[] = {11.166667, 19.833334};
TfLiteReducerParams params = {
false // keep_dims
};
@ -469,8 +611,9 @@ TF_LITE_MICRO_TEST(MeanUInt84DWithoutKeepDimsWithPrecision) {
tflite::testing::TestMeanOpQuantized<uint8_t>(
kInputShape4D, kInputData4D, input_data_quant, input_scale,
input_zero_point, tflite::testing::kAxisShape, tflite::testing::kAxisData,
kOutputShape, kGoldenData, output_data_quant, expected_output_data_quant,
output_scale, output_zero_point, &params);
input_zero_point, tflite::testing::kAxisShape4D,
tflite::testing::kAxisData4D, kOutputShape4D, kGoldenData4D,
output_data_quant, expected_output_data_quant, output_scale,
output_zero_point, &params);
}
TF_LITE_MICRO_TESTS_END