Unpack op now supports vector -> scalar
PiperOrigin-RevId: 260241806
This commit is contained in:
parent
7cb2266fcd
commit
743f73d3d6
@ -36,7 +36,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
|||||||
|
|
||||||
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
|
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
|
||||||
TF_LITE_ENSURE(context, NumDimensions(input) <= 4);
|
TF_LITE_ENSURE(context, NumDimensions(input) <= 4);
|
||||||
TF_LITE_ENSURE(context, NumDimensions(input) > 1);
|
TF_LITE_ENSURE(context, NumElements(input) > 0);
|
||||||
int axis = data->axis;
|
int axis = data->axis;
|
||||||
if (axis < 0) {
|
if (axis < 0) {
|
||||||
axis += NumDimensions(input);
|
axis += NumDimensions(input);
|
||||||
|
@ -126,6 +126,13 @@ TEST(UnpackOpTest, FloatThreeDimensionsOutputs) {
|
|||||||
/*expected_output_data=*/{{1, 3, 5, 7}, {2, 4, 6, 8}});
|
/*expected_output_data=*/{{1, 3, 5, 7}, {2, 4, 6, 8}});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(UnpackOpTest, FloatVectorToScalar) {
|
||||||
|
Check<float>(/*axis=*/0, /*input_shape=*/{5},
|
||||||
|
/*input_data=*/{1, 2, 3, 4, 5},
|
||||||
|
/*exp_output_shape=*/{{}, {}, {}, {}, {}},
|
||||||
|
/*exp_output_data=*/{{1}, {2}, {3}, {4}, {5}});
|
||||||
|
}
|
||||||
|
|
||||||
// int32 tests.
|
// int32 tests.
|
||||||
TEST(UnpackOpTest, IntThreeOutputs) {
|
TEST(UnpackOpTest, IntThreeOutputs) {
|
||||||
Check<int32_t>(/*axis=*/0, /*input_shape=*/{3, 2},
|
Check<int32_t>(/*axis=*/0, /*input_shape=*/{3, 2},
|
||||||
@ -159,6 +166,14 @@ TEST(UnpackOpTest, IntThreeDimensionsOutputs) {
|
|||||||
/*type=*/TensorType_INT32);
|
/*type=*/TensorType_INT32);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(UnpackOpTest, IntVectorToScalar) {
|
||||||
|
Check<int32_t>(/*axis=*/0, /*input_shape=*/{5},
|
||||||
|
/*input_data=*/{1, 2, 3, 4, 5},
|
||||||
|
/*exp_output_shape=*/{{}, {}, {}, {}, {}},
|
||||||
|
/*exp_output_data=*/{{1}, {2}, {3}, {4}, {5}},
|
||||||
|
/*type=*/TensorType_INT32);
|
||||||
|
}
|
||||||
|
|
||||||
// uint8 tests.
|
// uint8 tests.
|
||||||
TEST(UnpackOpTest, Uint8ThreeOutputs) {
|
TEST(UnpackOpTest, Uint8ThreeOutputs) {
|
||||||
Check<uint8_t>(/*axis=*/0, /*input_shape=*/{3, 2},
|
Check<uint8_t>(/*axis=*/0, /*input_shape=*/{3, 2},
|
||||||
@ -208,6 +223,14 @@ TEST(UnpackOpTest, Uint8ThreeDimensionsOutputs) {
|
|||||||
/*type=*/TensorType_UINT8);
|
/*type=*/TensorType_UINT8);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(UnpackOpTest, Uint8VectorToScalar) {
|
||||||
|
Check<uint8_t>(/*axis=*/0, /*input_shape=*/{5},
|
||||||
|
/*input_data=*/{1, 2, 3, 4, 5},
|
||||||
|
/*exp_output_shape=*/{{}, {}, {}, {}, {}},
|
||||||
|
/*exp_output_data=*/{{1}, {2}, {3}, {4}, {5}},
|
||||||
|
/*type=*/TensorType_UINT8);
|
||||||
|
}
|
||||||
|
|
||||||
// int8 tests.
|
// int8 tests.
|
||||||
TEST(UnpackOpTest, Int8ThreeOutputs) {
|
TEST(UnpackOpTest, Int8ThreeOutputs) {
|
||||||
Check<int8_t>(/*axis=*/0, /*input_shape=*/{3, 2},
|
Check<int8_t>(/*axis=*/0, /*input_shape=*/{3, 2},
|
||||||
@ -257,5 +280,13 @@ TEST(UnpackOpTest, Int8ThreeDimensionsOutputs) {
|
|||||||
/*type=*/TensorType_INT8);
|
/*type=*/TensorType_INT8);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(UnpackOpTest, Int8VectorToScalar) {
|
||||||
|
Check<int8_t>(/*axis=*/0, /*input_shape=*/{5},
|
||||||
|
/*input_data=*/{1, 2, 3, 4, 5},
|
||||||
|
/*exp_output_shape=*/{{}, {}, {}, {}, {}},
|
||||||
|
/*exp_output_data=*/{{1}, {2}, {3}, {4}, {5}},
|
||||||
|
/*type=*/TensorType_INT8);
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace tflite
|
} // namespace tflite
|
||||||
|
Loading…
Reference in New Issue
Block a user