Add split support to the NNAPI delegate

PiperOrigin-RevId: 256448757
This commit is contained in:
Jared Duke 2019-07-03 15:25:45 -07:00 committed by TensorFlower Gardener
parent 1a216e8946
commit a18d0fc601
3 changed files with 32 additions and 0 deletions

View File

@ -1920,6 +1920,28 @@ class NNAPIDelegateKernel {
};
}
} break;
case kTfLiteBuiltinSplit: {
// Tensor indices: split_dim: 0, value: 1
const TfLiteTensor& axis = context->tensors[node->inputs->data[0]];
const TfLiteTensor& input = context->tensors[node->inputs->data[1]];
if (version == 1 && android_sdk_version >= kMinSdkVersionForNNAPI12 &&
(input.type == kTfLiteFloat32 || input.type == kTfLiteUInt8 ||
input.type == kTfLiteInt32) &&
(axis.type == kTfLiteInt32 &&
axis.allocation_type == kTfLiteMmapRo)) {
return [](const NNAPIOpMappingArgs& mapping_args)
-> ANeuralNetworksOperationType {
const TfLiteTensor& axis =
mapping_args.context
->tensors[mapping_args.node->inputs->data[0]];
auto builtin = reinterpret_cast<TfLiteSplitParams*>(
mapping_args.node->builtin_data);
mapping_args.builder->AddScalarInt32Operand(*axis.data.i32);
mapping_args.builder->AddScalarInt32Operand(builtin->num_splits);
return ANEURALNETWORKS_SPLIT;
};
}
} break;
default:
// All other operators are not mapped.
return nullptr;
@ -2335,6 +2357,14 @@ class NNAPIDelegateKernel {
continue;
}
}
if ((reg->builtin_code == kTfLiteBuiltinSplit) &&
(input_index == node->inputs->data[0])) {
// Skip the axis input tensor; it will be added as a scalar operand
// by the Map() mapping.
continue;
}
// Pad and Padv2 have an optional parameter for a pad value which has
// to be converted to a scalar type in NN API.
if ((reg->builtin_code == kTfLiteBuiltinPadv2 ||

View File

@ -1231,6 +1231,7 @@ cc_test(
name = "split_test",
size = "small",
srcs = ["split_test.cc"],
tags = ["tflite_nnapi"],
deps = [
":builtin_ops",
":test_main",

View File

@ -115,6 +115,7 @@ enum {
ANEURALNETWORKS_RSQRT = 83,
ANEURALNETWORKS_SELECT = 84,
ANEURALNETWORKS_SIN = 85,
ANEURALNETWORKS_SPLIT = 87,
ANEURALNETWORKS_SQRT = 88,
ANEURALNETWORKS_TILE = 89,
ANEURALNETWORKS_TOPK_V2 = 90,