diff --git a/tensorflow/lite/micro/kernels/reshape_test.cc b/tensorflow/lite/micro/kernels/reshape_test.cc index 91ecbdc7a49..48d1956f1c8 100644 --- a/tensorflow/lite/micro/kernels/reshape_test.cc +++ b/tensorflow/lite/micro/kernels/reshape_test.cc @@ -23,7 +23,6 @@ limitations under the License. #include "tensorflow/lite/micro/micro_utils.h" #include "tensorflow/lite/micro/test_helpers.h" #include "tensorflow/lite/micro/testing/micro_test.h" -#include "tensorflow/lite/micro/testing/test_utils.h" namespace tflite { namespace testing { @@ -113,22 +112,41 @@ void TestReshapeWithoutShape(TfLiteTensor* input_tensor, expected_dims_len, expect_failure); } -template -void TestReshape(const int* input_dims_data, const T* input_data, +void TestReshape(const int* input_dims_data, const float* input_data, const int* shape_dims_data, const int32_t* shape_data, - int* output_dims_data, T* output_data, - const T* expected_output, const size_t expected_output_len, + int* output_dims_data, float* output_data, + const float* expected_output, const size_t expected_output_len, const int* expected_dims, const size_t expected_dims_len, bool expect_failure = false) { TfLiteIntArray* input_dims = IntArrayFromInts(input_dims_data); TfLiteIntArray* shape_dims = IntArrayFromInts(shape_dims_data); TfLiteIntArray* output_dims = IntArrayFromInts(output_dims_data); - TfLiteTensor input_tensor = - CreateTensor(input_data, input_dims); - TfLiteTensor shape_tensor = - CreateTensor(shape_data, shape_dims); - TfLiteTensor output_tensor = - CreateTensor(output_data, output_dims); + TfLiteTensor input_tensor = CreateFloatTensor(input_data, input_dims); + TfLiteTensor shape_tensor = CreateInt32Tensor(shape_data, shape_dims); + TfLiteTensor output_tensor = CreateFloatTensor(output_data, output_dims); + + TestReshapeWithShape(&input_tensor, &shape_tensor, &output_tensor, + expected_output, expected_output_len, expected_dims, + expected_dims_len, expect_failure); +} + +template +void TestReshapeQuantized(const int* input_dims_data, const T* input_data, + const int* shape_dims_data, const int32_t* shape_data, + int* output_dims_data, T* output_data, + const T* expected_output, + const size_t expected_output_len, + const int* expected_dims, + const size_t expected_dims_len, + bool expect_failure = false) { + TfLiteIntArray* input_dims = IntArrayFromInts(input_dims_data); + TfLiteIntArray* shape_dims = IntArrayFromInts(shape_dims_data); + TfLiteIntArray* output_dims = IntArrayFromInts(output_dims_data); + TfLiteTensor input_tensor = CreateQuantizedTensor( + input_data, input_dims, /*scale=*/1.f, /*zero_point=*/0); + TfLiteTensor shape_tensor = CreateInt32Tensor(shape_data, shape_dims); + TfLiteTensor output_tensor = CreateQuantizedTensor( + output_data, output_dims, /*scale=*/1.f, /*zero_point=*/0); TestReshapeWithShape(&input_tensor, &shape_tensor, &output_tensor, expected_output, expected_output_len, expected_dims, @@ -233,11 +251,11 @@ TF_LITE_MICRO_TEST(ReshapeWithRegularShapesShouldSucceed) { output_dims, output_data_float, golden_output_float, golden_output_len, golden_dims, golden_dims_len, false); - tflite::testing::TestReshape( + tflite::testing::TestReshapeQuantized( input_dims, input_int8, shape_dims, shape_int32, output_dims, output_data_int8, golden_output_int8, golden_output_len, golden_dims, golden_dims_len, false); - tflite::testing::TestReshape( + tflite::testing::TestReshapeQuantized( input_dims, input_uint8, shape_dims, shape_int32, output_dims, output_data_uint8, golden_output_uint8, golden_output_len, golden_dims, golden_dims_len, false); @@ -265,11 +283,11 @@ TF_LITE_MICRO_TEST(ReshapeWithStretchDimensionShouldSucceed) { output_dims, output_data_float, golden_output_float, golden_output_len, golden_dims, golden_dims_len, false); - tflite::testing::TestReshape( + tflite::testing::TestReshapeQuantized( input_dims, input_int8, shape_dims, shape_int32, output_dims, output_data_int8, golden_output_int8, golden_output_len, golden_dims, golden_dims_len, false); - tflite::testing::TestReshape( + tflite::testing::TestReshapeQuantized( input_dims, input_uint8, shape_dims, shape_int32, output_dims, output_data_uint8, golden_output_uint8, golden_output_len, golden_dims, golden_dims_len, false); @@ -297,11 +315,11 @@ TF_LITE_MICRO_TEST(ReshapeWithScalarOutputShouldSucceed) { output_dims, output_data_float, golden_output_float, golden_output_len, golden_dims, golden_dims_len, false); - tflite::testing::TestReshape( + tflite::testing::TestReshapeQuantized( input_dims, input_int8, shape_dims, shape_int32, output_dims, output_data_int8, golden_output_int8, golden_output_len, golden_dims, golden_dims_len, false); - tflite::testing::TestReshape( + tflite::testing::TestReshapeQuantized( input_dims, input_uint8, shape_dims, shape_int32, output_dims, output_data_uint8, golden_output_uint8, golden_output_len, golden_dims, golden_dims_len, false); @@ -327,8 +345,8 @@ TF_LITE_MICRO_TEST(ReshapeWithLegacyScalarOutputShouldSucceed) { TfLiteIntArray* shape_dims = IntArrayFromInts(shape_dims_data); const int32_t shape_data[] = {0}; - auto shape_tensor = tflite::testing::CreateTensor( - shape_data, shape_dims); + auto shape_tensor = + tflite::testing::CreateInt32Tensor(shape_data, shape_dims); const float expected_output_with_shape[] = {}; const int expected_output_with_shape_len = 0; const float expected_output_no_shape[] = {3};