Unpack op now supports vector -> scalar

PiperOrigin-RevId: 260241806
This commit is contained in:
Sachin Joglekar 2019-07-26 17:10:53 -07:00 committed by TensorFlower Gardener
parent 7cb2266fcd
commit 743f73d3d6
2 changed files with 32 additions and 1 deletions

View File

@ -36,7 +36,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
TF_LITE_ENSURE(context, NumDimensions(input) <= 4);
TF_LITE_ENSURE(context, NumDimensions(input) > 1);
TF_LITE_ENSURE(context, NumElements(input) > 0);
int axis = data->axis;
if (axis < 0) {
axis += NumDimensions(input);

View File

@ -126,6 +126,13 @@ TEST(UnpackOpTest, FloatThreeDimensionsOutputs) {
/*expected_output_data=*/{{1, 3, 5, 7}, {2, 4, 6, 8}});
}
TEST(UnpackOpTest, FloatVectorToScalar) {
Check<float>(/*axis=*/0, /*input_shape=*/{5},
/*input_data=*/{1, 2, 3, 4, 5},
/*exp_output_shape=*/{{}, {}, {}, {}, {}},
/*exp_output_data=*/{{1}, {2}, {3}, {4}, {5}});
}
// int32 tests.
TEST(UnpackOpTest, IntThreeOutputs) {
Check<int32_t>(/*axis=*/0, /*input_shape=*/{3, 2},
@ -159,6 +166,14 @@ TEST(UnpackOpTest, IntThreeDimensionsOutputs) {
/*type=*/TensorType_INT32);
}
TEST(UnpackOpTest, IntVectorToScalar) {
Check<int32_t>(/*axis=*/0, /*input_shape=*/{5},
/*input_data=*/{1, 2, 3, 4, 5},
/*exp_output_shape=*/{{}, {}, {}, {}, {}},
/*exp_output_data=*/{{1}, {2}, {3}, {4}, {5}},
/*type=*/TensorType_INT32);
}
// uint8 tests.
TEST(UnpackOpTest, Uint8ThreeOutputs) {
Check<uint8_t>(/*axis=*/0, /*input_shape=*/{3, 2},
@ -208,6 +223,14 @@ TEST(UnpackOpTest, Uint8ThreeDimensionsOutputs) {
/*type=*/TensorType_UINT8);
}
TEST(UnpackOpTest, Uint8VectorToScalar) {
Check<uint8_t>(/*axis=*/0, /*input_shape=*/{5},
/*input_data=*/{1, 2, 3, 4, 5},
/*exp_output_shape=*/{{}, {}, {}, {}, {}},
/*exp_output_data=*/{{1}, {2}, {3}, {4}, {5}},
/*type=*/TensorType_UINT8);
}
// int8 tests.
TEST(UnpackOpTest, Int8ThreeOutputs) {
Check<int8_t>(/*axis=*/0, /*input_shape=*/{3, 2},
@ -257,5 +280,13 @@ TEST(UnpackOpTest, Int8ThreeDimensionsOutputs) {
/*type=*/TensorType_INT8);
}
TEST(UnpackOpTest, Int8VectorToScalar) {
Check<int8_t>(/*axis=*/0, /*input_shape=*/{5},
/*input_data=*/{1, 2, 3, 4, 5},
/*exp_output_shape=*/{{}, {}, {}, {}, {}},
/*exp_output_data=*/{{1}, {2}, {3}, {4}, {5}},
/*type=*/TensorType_INT8);
}
} // namespace
} // namespace tflite