Merge pull request #28376 from joshbeal:lite_splitv

PiperOrigin-RevId: 255456586
This commit is contained in:
TensorFlower Gardener 2019-06-27 15:25:43 -07:00
commit 185d90fc19
2 changed files with 40 additions and 3 deletions

View File

@ -124,9 +124,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumOutputs(node), op_context.params->num_splits);
auto input_type = op_context.input->type;
TF_LITE_ENSURE(context, input_type == kTfLiteFloat32 ||
input_type == kTfLiteUInt8 ||
input_type == kTfLiteInt16);
TF_LITE_ENSURE(context,
input_type == kTfLiteFloat32 || input_type == kTfLiteUInt8 ||
input_type == kTfLiteInt16 || input_type == kTfLiteInt32 ||
input_type == kTfLiteInt64);
for (int i = 0; i < NumOutputs(node); ++i) {
GetOutput(context, node, i)->type = input_type;
}
@ -182,6 +183,14 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_SPLIT_V(int16_t);
break;
}
case kTfLiteInt32: {
TF_LITE_SPLIT_V(int32_t);
break;
}
case kTfLiteInt64: {
TF_LITE_SPLIT_V(int64_t);
break;
}
default:
context->ReportError(context, "Type %s currently not supported.",
TfLiteTypeGetName(op_context.input->type));

View File

@ -200,5 +200,33 @@ TEST(SplitVOpTest, TwoDimensionalInt16) {
{{1, 2, 3}, {4, 5, 6}, {7, 8, 9, 10, 11, 12}});
}
TEST(SplitVOpTest, TwoDimensionalInt32) {
// Input shape: {4, 3}
// size_splits: {1, 1, 2}
// axis: 0
// We should have 3 outpus with shapes respectively:
// output 1 : {1, 3}
// output 2 : {1, 3}
// output 3 : {2, 3}
Check<int32_t, TensorType_INT32>(
/*axis=*/0, {4, 3}, {3}, {{1, 3}, {1, 3}, {2, 3}},
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, {1, 1, 2},
{{1, 2, 3}, {4, 5, 6}, {7, 8, 9, 10, 11, 12}});
}
TEST(SplitVOpTest, TwoDimensionalInt64) {
// Input shape: {4, 3}
// size_splits: {1, 1, 2}
// axis: 0
// We should have 3 outpus with shapes respectively:
// output 1 : {1, 3}
// output 2 : {1, 3}
// output 3 : {2, 3}
Check<int64_t, TensorType_INT64>(
/*axis=*/0, {4, 3}, {3}, {{1, 3}, {1, 3}, {2, 3}},
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, {1, 1, 2},
{{1, 2, 3}, {4, 5, 6}, {7, 8, 9, 10, 11, 12}});
}
} // namespace
} // namespace tflite