Remove dimension check in TFLite unpack

The current implementation can support arbitrary dimension

PiperOrigin-RevId: 294843677
Change-Id: Id22d4e360a22704f90345886f5b04465c54462e6
This commit is contained in:
Thai Nguyen 2020-02-12 23:19:11 -08:00 committed by TensorFlower Gardener
parent a4f980dc5e
commit d0322efb28
2 changed files with 52 additions and 1 deletions
tensorflow/lite/kernels

View File

@ -35,7 +35,6 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumOutputs(node), data->num);
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
TF_LITE_ENSURE(context, NumDimensions(input) <= 4);
TF_LITE_ENSURE(context, NumElements(input) > 0);
int axis = data->axis;
if (axis < 0) {

View File

@ -126,6 +126,15 @@ TEST(UnpackOpTest, FloatThreeDimensionsOutputs) {
/*exp_output_data=*/{{1, 3, 5, 7}, {2, 4, 6, 8}});
}
TEST(UnpackOpTest, FloatFiveDimensionsOutputs) {
Check<float>(
/*axis=*/2, /*input_shape=*/{2, 2, 2, 2, 1},
/*input_data=*/{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
/*exp_output_shape=*/{{2, 2, 2, 1}, {2, 2, 2, 1}},
/*exp_output_data=*/
{{1, 2, 5, 6, 9, 10, 13, 14}, {3, 4, 7, 8, 11, 12, 15, 16}});
}
TEST(UnpackOpTest, FloatVectorToScalar) {
Check<float>(/*axis=*/0, /*input_shape=*/{5},
/*input_data=*/{1, 2, 3, 4, 5},
@ -166,6 +175,16 @@ TEST(UnpackOpTest, IntThreeDimensionsOutputs) {
/*type=*/TensorType_INT32);
}
TEST(UnpackOpTest, IntFiveDimensionsOutputs) {
Check<int32_t>(
/*axis=*/2, /*input_shape=*/{2, 2, 2, 2, 1},
/*input_data=*/{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
/*exp_output_shape=*/{{2, 2, 2, 1}, {2, 2, 2, 1}},
/*exp_output_data=*/
{{1, 2, 5, 6, 9, 10, 13, 14}, {3, 4, 7, 8, 11, 12, 15, 16}},
/*type=*/TensorType_INT32);
}
TEST(UnpackOpTest, IntVectorToScalar) {
Check<int32_t>(/*axis=*/0, /*input_shape=*/{5},
/*input_data=*/{1, 2, 3, 4, 5},
@ -223,6 +242,16 @@ TEST(UnpackOpTest, Uint8ThreeDimensionsOutputs) {
/*type=*/TensorType_UINT8);
}
TEST(UnpackOpTest, Uint8FiveDimensionsOutputs) {
Check<uint8_t>(
/*axis=*/2, /*input_shape=*/{2, 2, 2, 2, 1},
/*input_data=*/{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
/*exp_output_shape=*/{{2, 2, 2, 1}, {2, 2, 2, 1}},
/*exp_output_data=*/
{{1, 2, 5, 6, 9, 10, 13, 14}, {3, 4, 7, 8, 11, 12, 15, 16}},
/*type=*/TensorType_UINT8);
}
TEST(UnpackOpTest, Uint8VectorToScalar) {
Check<uint8_t>(/*axis=*/0, /*input_shape=*/{5},
/*input_data=*/{1, 2, 3, 4, 5},
@ -280,6 +309,16 @@ TEST(UnpackOpTest, Int8ThreeDimensionsOutputs) {
/*type=*/TensorType_INT8);
}
TEST(UnpackOpTest, Int8FiveDimensionsOutputs) {
Check<int8_t>(
/*axis=*/2, /*input_shape=*/{2, 2, 2, 2, 1},
/*input_data=*/{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
/*exp_output_shape=*/{{2, 2, 2, 1}, {2, 2, 2, 1}},
/*exp_output_data=*/
{{1, 2, 5, 6, 9, 10, 13, 14}, {3, 4, 7, 8, 11, 12, 15, 16}},
/*type=*/TensorType_INT8);
}
TEST(UnpackOpTest, Int8VectorToScalar) {
Check<int8_t>(/*axis=*/0, /*input_shape=*/{5},
/*input_data=*/{1, 2, 3, 4, 5},
@ -344,6 +383,19 @@ TEST(UnpackOpTest, BoolThreeDimensionsOutputs) {
/*type=*/TensorType_BOOL);
}
TEST(UnpackOpTest, BoolFiveDimensionsOutputs) {
Check<bool>(
/*axis=*/2, /*input_shape=*/{2, 2, 2, 2, 1},
/*input_data=*/
{true, false, true, false, true, false, true, false, true, true, true,
true, true, true, true, true},
/*exp_output_shape=*/{{2, 2, 2, 1}, {2, 2, 2, 1}},
/*exp_output_data=*/
{{true, false, true, false, true, true, true, true},
{true, false, true, false, true, true, true, true}},
/*type=*/TensorType_BOOL);
}
TEST(UnpackOpTest, BoolVectorToScalar) {
Check<bool>(/*axis=*/0, /*input_shape=*/{5},
/*input_data=*/{true, false, true, false, true},