Remove dimension check in TFLite unpack
The current implementation can support arbitrary dimension PiperOrigin-RevId: 294843677 Change-Id: Id22d4e360a22704f90345886f5b04465c54462e6
This commit is contained in:
parent
a4f980dc5e
commit
d0322efb28
tensorflow/lite/kernels
@ -35,7 +35,6 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
TF_LITE_ENSURE_EQ(context, NumOutputs(node), data->num);
|
||||
|
||||
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
|
||||
TF_LITE_ENSURE(context, NumDimensions(input) <= 4);
|
||||
TF_LITE_ENSURE(context, NumElements(input) > 0);
|
||||
int axis = data->axis;
|
||||
if (axis < 0) {
|
||||
|
@ -126,6 +126,15 @@ TEST(UnpackOpTest, FloatThreeDimensionsOutputs) {
|
||||
/*exp_output_data=*/{{1, 3, 5, 7}, {2, 4, 6, 8}});
|
||||
}
|
||||
|
||||
TEST(UnpackOpTest, FloatFiveDimensionsOutputs) {
|
||||
Check<float>(
|
||||
/*axis=*/2, /*input_shape=*/{2, 2, 2, 2, 1},
|
||||
/*input_data=*/{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
|
||||
/*exp_output_shape=*/{{2, 2, 2, 1}, {2, 2, 2, 1}},
|
||||
/*exp_output_data=*/
|
||||
{{1, 2, 5, 6, 9, 10, 13, 14}, {3, 4, 7, 8, 11, 12, 15, 16}});
|
||||
}
|
||||
|
||||
TEST(UnpackOpTest, FloatVectorToScalar) {
|
||||
Check<float>(/*axis=*/0, /*input_shape=*/{5},
|
||||
/*input_data=*/{1, 2, 3, 4, 5},
|
||||
@ -166,6 +175,16 @@ TEST(UnpackOpTest, IntThreeDimensionsOutputs) {
|
||||
/*type=*/TensorType_INT32);
|
||||
}
|
||||
|
||||
TEST(UnpackOpTest, IntFiveDimensionsOutputs) {
|
||||
Check<int32_t>(
|
||||
/*axis=*/2, /*input_shape=*/{2, 2, 2, 2, 1},
|
||||
/*input_data=*/{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
|
||||
/*exp_output_shape=*/{{2, 2, 2, 1}, {2, 2, 2, 1}},
|
||||
/*exp_output_data=*/
|
||||
{{1, 2, 5, 6, 9, 10, 13, 14}, {3, 4, 7, 8, 11, 12, 15, 16}},
|
||||
/*type=*/TensorType_INT32);
|
||||
}
|
||||
|
||||
TEST(UnpackOpTest, IntVectorToScalar) {
|
||||
Check<int32_t>(/*axis=*/0, /*input_shape=*/{5},
|
||||
/*input_data=*/{1, 2, 3, 4, 5},
|
||||
@ -223,6 +242,16 @@ TEST(UnpackOpTest, Uint8ThreeDimensionsOutputs) {
|
||||
/*type=*/TensorType_UINT8);
|
||||
}
|
||||
|
||||
TEST(UnpackOpTest, Uint8FiveDimensionsOutputs) {
|
||||
Check<uint8_t>(
|
||||
/*axis=*/2, /*input_shape=*/{2, 2, 2, 2, 1},
|
||||
/*input_data=*/{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
|
||||
/*exp_output_shape=*/{{2, 2, 2, 1}, {2, 2, 2, 1}},
|
||||
/*exp_output_data=*/
|
||||
{{1, 2, 5, 6, 9, 10, 13, 14}, {3, 4, 7, 8, 11, 12, 15, 16}},
|
||||
/*type=*/TensorType_UINT8);
|
||||
}
|
||||
|
||||
TEST(UnpackOpTest, Uint8VectorToScalar) {
|
||||
Check<uint8_t>(/*axis=*/0, /*input_shape=*/{5},
|
||||
/*input_data=*/{1, 2, 3, 4, 5},
|
||||
@ -280,6 +309,16 @@ TEST(UnpackOpTest, Int8ThreeDimensionsOutputs) {
|
||||
/*type=*/TensorType_INT8);
|
||||
}
|
||||
|
||||
TEST(UnpackOpTest, Int8FiveDimensionsOutputs) {
|
||||
Check<int8_t>(
|
||||
/*axis=*/2, /*input_shape=*/{2, 2, 2, 2, 1},
|
||||
/*input_data=*/{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
|
||||
/*exp_output_shape=*/{{2, 2, 2, 1}, {2, 2, 2, 1}},
|
||||
/*exp_output_data=*/
|
||||
{{1, 2, 5, 6, 9, 10, 13, 14}, {3, 4, 7, 8, 11, 12, 15, 16}},
|
||||
/*type=*/TensorType_INT8);
|
||||
}
|
||||
|
||||
TEST(UnpackOpTest, Int8VectorToScalar) {
|
||||
Check<int8_t>(/*axis=*/0, /*input_shape=*/{5},
|
||||
/*input_data=*/{1, 2, 3, 4, 5},
|
||||
@ -344,6 +383,19 @@ TEST(UnpackOpTest, BoolThreeDimensionsOutputs) {
|
||||
/*type=*/TensorType_BOOL);
|
||||
}
|
||||
|
||||
TEST(UnpackOpTest, BoolFiveDimensionsOutputs) {
|
||||
Check<bool>(
|
||||
/*axis=*/2, /*input_shape=*/{2, 2, 2, 2, 1},
|
||||
/*input_data=*/
|
||||
{true, false, true, false, true, false, true, false, true, true, true,
|
||||
true, true, true, true, true},
|
||||
/*exp_output_shape=*/{{2, 2, 2, 1}, {2, 2, 2, 1}},
|
||||
/*exp_output_data=*/
|
||||
{{true, false, true, false, true, true, true, true},
|
||||
{true, false, true, false, true, true, true, true}},
|
||||
/*type=*/TensorType_BOOL);
|
||||
}
|
||||
|
||||
TEST(UnpackOpTest, BoolVectorToScalar) {
|
||||
Check<bool>(/*axis=*/0, /*input_shape=*/{5},
|
||||
/*input_data=*/{true, false, true, false, true},
|
||||
|
Loading…
Reference in New Issue
Block a user