Unpack op now supports vector -> scalar
PiperOrigin-RevId: 260241806
This commit is contained in:
parent
7cb2266fcd
commit
743f73d3d6
@ -36,7 +36,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
|
||||
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
|
||||
TF_LITE_ENSURE(context, NumDimensions(input) <= 4);
|
||||
TF_LITE_ENSURE(context, NumDimensions(input) > 1);
|
||||
TF_LITE_ENSURE(context, NumElements(input) > 0);
|
||||
int axis = data->axis;
|
||||
if (axis < 0) {
|
||||
axis += NumDimensions(input);
|
||||
|
@ -126,6 +126,13 @@ TEST(UnpackOpTest, FloatThreeDimensionsOutputs) {
|
||||
/*expected_output_data=*/{{1, 3, 5, 7}, {2, 4, 6, 8}});
|
||||
}
|
||||
|
||||
TEST(UnpackOpTest, FloatVectorToScalar) {
|
||||
Check<float>(/*axis=*/0, /*input_shape=*/{5},
|
||||
/*input_data=*/{1, 2, 3, 4, 5},
|
||||
/*exp_output_shape=*/{{}, {}, {}, {}, {}},
|
||||
/*exp_output_data=*/{{1}, {2}, {3}, {4}, {5}});
|
||||
}
|
||||
|
||||
// int32 tests.
|
||||
TEST(UnpackOpTest, IntThreeOutputs) {
|
||||
Check<int32_t>(/*axis=*/0, /*input_shape=*/{3, 2},
|
||||
@ -159,6 +166,14 @@ TEST(UnpackOpTest, IntThreeDimensionsOutputs) {
|
||||
/*type=*/TensorType_INT32);
|
||||
}
|
||||
|
||||
TEST(UnpackOpTest, IntVectorToScalar) {
|
||||
Check<int32_t>(/*axis=*/0, /*input_shape=*/{5},
|
||||
/*input_data=*/{1, 2, 3, 4, 5},
|
||||
/*exp_output_shape=*/{{}, {}, {}, {}, {}},
|
||||
/*exp_output_data=*/{{1}, {2}, {3}, {4}, {5}},
|
||||
/*type=*/TensorType_INT32);
|
||||
}
|
||||
|
||||
// uint8 tests.
|
||||
TEST(UnpackOpTest, Uint8ThreeOutputs) {
|
||||
Check<uint8_t>(/*axis=*/0, /*input_shape=*/{3, 2},
|
||||
@ -208,6 +223,14 @@ TEST(UnpackOpTest, Uint8ThreeDimensionsOutputs) {
|
||||
/*type=*/TensorType_UINT8);
|
||||
}
|
||||
|
||||
TEST(UnpackOpTest, Uint8VectorToScalar) {
|
||||
Check<uint8_t>(/*axis=*/0, /*input_shape=*/{5},
|
||||
/*input_data=*/{1, 2, 3, 4, 5},
|
||||
/*exp_output_shape=*/{{}, {}, {}, {}, {}},
|
||||
/*exp_output_data=*/{{1}, {2}, {3}, {4}, {5}},
|
||||
/*type=*/TensorType_UINT8);
|
||||
}
|
||||
|
||||
// int8 tests.
|
||||
TEST(UnpackOpTest, Int8ThreeOutputs) {
|
||||
Check<int8_t>(/*axis=*/0, /*input_shape=*/{3, 2},
|
||||
@ -257,5 +280,13 @@ TEST(UnpackOpTest, Int8ThreeDimensionsOutputs) {
|
||||
/*type=*/TensorType_INT8);
|
||||
}
|
||||
|
||||
TEST(UnpackOpTest, Int8VectorToScalar) {
|
||||
Check<int8_t>(/*axis=*/0, /*input_shape=*/{5},
|
||||
/*input_data=*/{1, 2, 3, 4, 5},
|
||||
/*exp_output_shape=*/{{}, {}, {}, {}, {}},
|
||||
/*exp_output_data=*/{{1}, {2}, {3}, {4}, {5}},
|
||||
/*type=*/TensorType_INT8);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tflite
|
||||
|
Loading…
Reference in New Issue
Block a user