From d0322efb2869e04540c1e1e3b70604891aaae91a Mon Sep 17 00:00:00 2001
From: Thai Nguyen <thaink@google.com>
Date: Wed, 12 Feb 2020 23:19:11 -0800
Subject: [PATCH] Remove dimension check in TFLite unpack

The current implementation can support arbitrary dimension

PiperOrigin-RevId: 294843677
Change-Id: Id22d4e360a22704f90345886f5b04465c54462e6
---
 tensorflow/lite/kernels/unpack.cc      |  1 -
 tensorflow/lite/kernels/unpack_test.cc | 52 ++++++++++++++++++++++++++
 2 files changed, 52 insertions(+), 1 deletion(-)

diff --git a/tensorflow/lite/kernels/unpack.cc b/tensorflow/lite/kernels/unpack.cc
index 8e66432e9cd..9ddee6b30bd 100644
--- a/tensorflow/lite/kernels/unpack.cc
+++ b/tensorflow/lite/kernels/unpack.cc
@@ -35,7 +35,6 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   TF_LITE_ENSURE_EQ(context, NumOutputs(node), data->num);
 
   const TfLiteTensor* input = GetInput(context, node, kInputTensor);
-  TF_LITE_ENSURE(context, NumDimensions(input) <= 4);
   TF_LITE_ENSURE(context, NumElements(input) > 0);
   int axis = data->axis;
   if (axis < 0) {
diff --git a/tensorflow/lite/kernels/unpack_test.cc b/tensorflow/lite/kernels/unpack_test.cc
index 88eb706e969..9413d5e2873 100644
--- a/tensorflow/lite/kernels/unpack_test.cc
+++ b/tensorflow/lite/kernels/unpack_test.cc
@@ -126,6 +126,15 @@ TEST(UnpackOpTest, FloatThreeDimensionsOutputs) {
                /*exp_output_data=*/{{1, 3, 5, 7}, {2, 4, 6, 8}});
 }
 
+TEST(UnpackOpTest, FloatFiveDimensionsOutputs) {
+  Check<float>(
+      /*axis=*/2, /*input_shape=*/{2, 2, 2, 2, 1},
+      /*input_data=*/{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
+      /*exp_output_shape=*/{{2, 2, 2, 1}, {2, 2, 2, 1}},
+      /*exp_output_data=*/
+      {{1, 2, 5, 6, 9, 10, 13, 14}, {3, 4, 7, 8, 11, 12, 15, 16}});
+}
+
 TEST(UnpackOpTest, FloatVectorToScalar) {
   Check<float>(/*axis=*/0, /*input_shape=*/{5},
                /*input_data=*/{1, 2, 3, 4, 5},
@@ -166,6 +175,16 @@ TEST(UnpackOpTest, IntThreeDimensionsOutputs) {
                  /*type=*/TensorType_INT32);
 }
 
+TEST(UnpackOpTest, IntFiveDimensionsOutputs) {
+  Check<int32_t>(
+      /*axis=*/2, /*input_shape=*/{2, 2, 2, 2, 1},
+      /*input_data=*/{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
+      /*exp_output_shape=*/{{2, 2, 2, 1}, {2, 2, 2, 1}},
+      /*exp_output_data=*/
+      {{1, 2, 5, 6, 9, 10, 13, 14}, {3, 4, 7, 8, 11, 12, 15, 16}},
+      /*type=*/TensorType_INT32);
+}
+
 TEST(UnpackOpTest, IntVectorToScalar) {
   Check<int32_t>(/*axis=*/0, /*input_shape=*/{5},
                  /*input_data=*/{1, 2, 3, 4, 5},
@@ -223,6 +242,16 @@ TEST(UnpackOpTest, Uint8ThreeDimensionsOutputs) {
                  /*type=*/TensorType_UINT8);
 }
 
+TEST(UnpackOpTest, Uint8FiveDimensionsOutputs) {
+  Check<uint8_t>(
+      /*axis=*/2, /*input_shape=*/{2, 2, 2, 2, 1},
+      /*input_data=*/{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
+      /*exp_output_shape=*/{{2, 2, 2, 1}, {2, 2, 2, 1}},
+      /*exp_output_data=*/
+      {{1, 2, 5, 6, 9, 10, 13, 14}, {3, 4, 7, 8, 11, 12, 15, 16}},
+      /*type=*/TensorType_UINT8);
+}
+
 TEST(UnpackOpTest, Uint8VectorToScalar) {
   Check<uint8_t>(/*axis=*/0, /*input_shape=*/{5},
                  /*input_data=*/{1, 2, 3, 4, 5},
@@ -280,6 +309,16 @@ TEST(UnpackOpTest, Int8ThreeDimensionsOutputs) {
                 /*type=*/TensorType_INT8);
 }
 
+TEST(UnpackOpTest, Int8FiveDimensionsOutputs) {
+  Check<int8_t>(
+      /*axis=*/2, /*input_shape=*/{2, 2, 2, 2, 1},
+      /*input_data=*/{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
+      /*exp_output_shape=*/{{2, 2, 2, 1}, {2, 2, 2, 1}},
+      /*exp_output_data=*/
+      {{1, 2, 5, 6, 9, 10, 13, 14}, {3, 4, 7, 8, 11, 12, 15, 16}},
+      /*type=*/TensorType_INT8);
+}
+
 TEST(UnpackOpTest, Int8VectorToScalar) {
   Check<int8_t>(/*axis=*/0, /*input_shape=*/{5},
                 /*input_data=*/{1, 2, 3, 4, 5},
@@ -344,6 +383,19 @@ TEST(UnpackOpTest, BoolThreeDimensionsOutputs) {
       /*type=*/TensorType_BOOL);
 }
 
+TEST(UnpackOpTest, BoolFiveDimensionsOutputs) {
+  Check<bool>(
+      /*axis=*/2, /*input_shape=*/{2, 2, 2, 2, 1},
+      /*input_data=*/
+      {true, false, true, false, true, false, true, false, true, true, true,
+       true, true, true, true, true},
+      /*exp_output_shape=*/{{2, 2, 2, 1}, {2, 2, 2, 1}},
+      /*exp_output_data=*/
+      {{true, false, true, false, true, true, true, true},
+       {true, false, true, false, true, true, true, true}},
+      /*type=*/TensorType_BOOL);
+}
+
 TEST(UnpackOpTest, BoolVectorToScalar) {
   Check<bool>(/*axis=*/0, /*input_shape=*/{5},
               /*input_data=*/{true, false, true, false, true},