Merge pull request #28376 from joshbeal:lite_splitv
PiperOrigin-RevId: 255456586
This commit is contained in:
commit
185d90fc19
@ -124,9 +124,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
TF_LITE_ENSURE_EQ(context, NumOutputs(node), op_context.params->num_splits);
|
||||
|
||||
auto input_type = op_context.input->type;
|
||||
TF_LITE_ENSURE(context, input_type == kTfLiteFloat32 ||
|
||||
input_type == kTfLiteUInt8 ||
|
||||
input_type == kTfLiteInt16);
|
||||
TF_LITE_ENSURE(context,
|
||||
input_type == kTfLiteFloat32 || input_type == kTfLiteUInt8 ||
|
||||
input_type == kTfLiteInt16 || input_type == kTfLiteInt32 ||
|
||||
input_type == kTfLiteInt64);
|
||||
for (int i = 0; i < NumOutputs(node); ++i) {
|
||||
GetOutput(context, node, i)->type = input_type;
|
||||
}
|
||||
@ -182,6 +183,14 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
TF_LITE_SPLIT_V(int16_t);
|
||||
break;
|
||||
}
|
||||
case kTfLiteInt32: {
|
||||
TF_LITE_SPLIT_V(int32_t);
|
||||
break;
|
||||
}
|
||||
case kTfLiteInt64: {
|
||||
TF_LITE_SPLIT_V(int64_t);
|
||||
break;
|
||||
}
|
||||
default:
|
||||
context->ReportError(context, "Type %s currently not supported.",
|
||||
TfLiteTypeGetName(op_context.input->type));
|
||||
|
@ -200,5 +200,33 @@ TEST(SplitVOpTest, TwoDimensionalInt16) {
|
||||
{{1, 2, 3}, {4, 5, 6}, {7, 8, 9, 10, 11, 12}});
|
||||
}
|
||||
|
||||
TEST(SplitVOpTest, TwoDimensionalInt32) {
|
||||
// Input shape: {4, 3}
|
||||
// size_splits: {1, 1, 2}
|
||||
// axis: 0
|
||||
// We should have 3 outpus with shapes respectively:
|
||||
// output 1 : {1, 3}
|
||||
// output 2 : {1, 3}
|
||||
// output 3 : {2, 3}
|
||||
Check<int32_t, TensorType_INT32>(
|
||||
/*axis=*/0, {4, 3}, {3}, {{1, 3}, {1, 3}, {2, 3}},
|
||||
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, {1, 1, 2},
|
||||
{{1, 2, 3}, {4, 5, 6}, {7, 8, 9, 10, 11, 12}});
|
||||
}
|
||||
|
||||
TEST(SplitVOpTest, TwoDimensionalInt64) {
|
||||
// Input shape: {4, 3}
|
||||
// size_splits: {1, 1, 2}
|
||||
// axis: 0
|
||||
// We should have 3 outpus with shapes respectively:
|
||||
// output 1 : {1, 3}
|
||||
// output 2 : {1, 3}
|
||||
// output 3 : {2, 3}
|
||||
Check<int64_t, TensorType_INT64>(
|
||||
/*axis=*/0, {4, 3}, {3}, {{1, 3}, {1, 3}, {2, 3}},
|
||||
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, {1, 1, 2},
|
||||
{{1, 2, 3}, {4, 5, 6}, {7, 8, 9, 10, 11, 12}});
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tflite
|
||||
|
Loading…
Reference in New Issue
Block a user