From d0322efb2869e04540c1e1e3b70604891aaae91a Mon Sep 17 00:00:00 2001 From: Thai Nguyen <thaink@google.com> Date: Wed, 12 Feb 2020 23:19:11 -0800 Subject: [PATCH] Remove dimension check in TFLite unpack The current implementation can support arbitrary dimension PiperOrigin-RevId: 294843677 Change-Id: Id22d4e360a22704f90345886f5b04465c54462e6 --- tensorflow/lite/kernels/unpack.cc | 1 - tensorflow/lite/kernels/unpack_test.cc | 52 ++++++++++++++++++++++++++ 2 files changed, 52 insertions(+), 1 deletion(-) diff --git a/tensorflow/lite/kernels/unpack.cc b/tensorflow/lite/kernels/unpack.cc index 8e66432e9cd..9ddee6b30bd 100644 --- a/tensorflow/lite/kernels/unpack.cc +++ b/tensorflow/lite/kernels/unpack.cc @@ -35,7 +35,6 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_EQ(context, NumOutputs(node), data->num); const TfLiteTensor* input = GetInput(context, node, kInputTensor); - TF_LITE_ENSURE(context, NumDimensions(input) <= 4); TF_LITE_ENSURE(context, NumElements(input) > 0); int axis = data->axis; if (axis < 0) { diff --git a/tensorflow/lite/kernels/unpack_test.cc b/tensorflow/lite/kernels/unpack_test.cc index 88eb706e969..9413d5e2873 100644 --- a/tensorflow/lite/kernels/unpack_test.cc +++ b/tensorflow/lite/kernels/unpack_test.cc @@ -126,6 +126,15 @@ TEST(UnpackOpTest, FloatThreeDimensionsOutputs) { /*exp_output_data=*/{{1, 3, 5, 7}, {2, 4, 6, 8}}); } +TEST(UnpackOpTest, FloatFiveDimensionsOutputs) { + Check<float>( + /*axis=*/2, /*input_shape=*/{2, 2, 2, 2, 1}, + /*input_data=*/{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, + /*exp_output_shape=*/{{2, 2, 2, 1}, {2, 2, 2, 1}}, + /*exp_output_data=*/ + {{1, 2, 5, 6, 9, 10, 13, 14}, {3, 4, 7, 8, 11, 12, 15, 16}}); +} + TEST(UnpackOpTest, FloatVectorToScalar) { Check<float>(/*axis=*/0, /*input_shape=*/{5}, /*input_data=*/{1, 2, 3, 4, 5}, @@ -166,6 +175,16 @@ TEST(UnpackOpTest, IntThreeDimensionsOutputs) { /*type=*/TensorType_INT32); } +TEST(UnpackOpTest, IntFiveDimensionsOutputs) { + Check<int32_t>( + /*axis=*/2, /*input_shape=*/{2, 2, 2, 2, 1}, + /*input_data=*/{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, + /*exp_output_shape=*/{{2, 2, 2, 1}, {2, 2, 2, 1}}, + /*exp_output_data=*/ + {{1, 2, 5, 6, 9, 10, 13, 14}, {3, 4, 7, 8, 11, 12, 15, 16}}, + /*type=*/TensorType_INT32); +} + TEST(UnpackOpTest, IntVectorToScalar) { Check<int32_t>(/*axis=*/0, /*input_shape=*/{5}, /*input_data=*/{1, 2, 3, 4, 5}, @@ -223,6 +242,16 @@ TEST(UnpackOpTest, Uint8ThreeDimensionsOutputs) { /*type=*/TensorType_UINT8); } +TEST(UnpackOpTest, Uint8FiveDimensionsOutputs) { + Check<uint8_t>( + /*axis=*/2, /*input_shape=*/{2, 2, 2, 2, 1}, + /*input_data=*/{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, + /*exp_output_shape=*/{{2, 2, 2, 1}, {2, 2, 2, 1}}, + /*exp_output_data=*/ + {{1, 2, 5, 6, 9, 10, 13, 14}, {3, 4, 7, 8, 11, 12, 15, 16}}, + /*type=*/TensorType_UINT8); +} + TEST(UnpackOpTest, Uint8VectorToScalar) { Check<uint8_t>(/*axis=*/0, /*input_shape=*/{5}, /*input_data=*/{1, 2, 3, 4, 5}, @@ -280,6 +309,16 @@ TEST(UnpackOpTest, Int8ThreeDimensionsOutputs) { /*type=*/TensorType_INT8); } +TEST(UnpackOpTest, Int8FiveDimensionsOutputs) { + Check<int8_t>( + /*axis=*/2, /*input_shape=*/{2, 2, 2, 2, 1}, + /*input_data=*/{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, + /*exp_output_shape=*/{{2, 2, 2, 1}, {2, 2, 2, 1}}, + /*exp_output_data=*/ + {{1, 2, 5, 6, 9, 10, 13, 14}, {3, 4, 7, 8, 11, 12, 15, 16}}, + /*type=*/TensorType_INT8); +} + TEST(UnpackOpTest, Int8VectorToScalar) { Check<int8_t>(/*axis=*/0, /*input_shape=*/{5}, /*input_data=*/{1, 2, 3, 4, 5}, @@ -344,6 +383,19 @@ TEST(UnpackOpTest, BoolThreeDimensionsOutputs) { /*type=*/TensorType_BOOL); } +TEST(UnpackOpTest, BoolFiveDimensionsOutputs) { + Check<bool>( + /*axis=*/2, /*input_shape=*/{2, 2, 2, 2, 1}, + /*input_data=*/ + {true, false, true, false, true, false, true, false, true, true, true, + true, true, true, true, true}, + /*exp_output_shape=*/{{2, 2, 2, 1}, {2, 2, 2, 1}}, + /*exp_output_data=*/ + {{true, false, true, false, true, true, true, true}, + {true, false, true, false, true, true, true, true}}, + /*type=*/TensorType_BOOL); +} + TEST(UnpackOpTest, BoolVectorToScalar) { Check<bool>(/*axis=*/0, /*input_shape=*/{5}, /*input_data=*/{true, false, true, false, true},