Add split support to the NNAPI delegate
PiperOrigin-RevId: 256448757
This commit is contained in:
parent
1a216e8946
commit
a18d0fc601
@ -1920,6 +1920,28 @@ class NNAPIDelegateKernel {
|
||||
};
|
||||
}
|
||||
} break;
|
||||
case kTfLiteBuiltinSplit: {
|
||||
// Tensor indices: split_dim: 0, value: 1
|
||||
const TfLiteTensor& axis = context->tensors[node->inputs->data[0]];
|
||||
const TfLiteTensor& input = context->tensors[node->inputs->data[1]];
|
||||
if (version == 1 && android_sdk_version >= kMinSdkVersionForNNAPI12 &&
|
||||
(input.type == kTfLiteFloat32 || input.type == kTfLiteUInt8 ||
|
||||
input.type == kTfLiteInt32) &&
|
||||
(axis.type == kTfLiteInt32 &&
|
||||
axis.allocation_type == kTfLiteMmapRo)) {
|
||||
return [](const NNAPIOpMappingArgs& mapping_args)
|
||||
-> ANeuralNetworksOperationType {
|
||||
const TfLiteTensor& axis =
|
||||
mapping_args.context
|
||||
->tensors[mapping_args.node->inputs->data[0]];
|
||||
auto builtin = reinterpret_cast<TfLiteSplitParams*>(
|
||||
mapping_args.node->builtin_data);
|
||||
mapping_args.builder->AddScalarInt32Operand(*axis.data.i32);
|
||||
mapping_args.builder->AddScalarInt32Operand(builtin->num_splits);
|
||||
return ANEURALNETWORKS_SPLIT;
|
||||
};
|
||||
}
|
||||
} break;
|
||||
default:
|
||||
// All other operators are not mapped.
|
||||
return nullptr;
|
||||
@ -2335,6 +2357,14 @@ class NNAPIDelegateKernel {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
if ((reg->builtin_code == kTfLiteBuiltinSplit) &&
|
||||
(input_index == node->inputs->data[0])) {
|
||||
// Skip the axis input tensor; it will be added as a scalar operand
|
||||
// by the Map() mapping.
|
||||
continue;
|
||||
}
|
||||
|
||||
// Pad and Padv2 have an optional parameter for a pad value which has
|
||||
// to be converted to a scalar type in NN API.
|
||||
if ((reg->builtin_code == kTfLiteBuiltinPadv2 ||
|
||||
|
@ -1231,6 +1231,7 @@ cc_test(
|
||||
name = "split_test",
|
||||
size = "small",
|
||||
srcs = ["split_test.cc"],
|
||||
tags = ["tflite_nnapi"],
|
||||
deps = [
|
||||
":builtin_ops",
|
||||
":test_main",
|
||||
|
@ -115,6 +115,7 @@ enum {
|
||||
ANEURALNETWORKS_RSQRT = 83,
|
||||
ANEURALNETWORKS_SELECT = 84,
|
||||
ANEURALNETWORKS_SIN = 85,
|
||||
ANEURALNETWORKS_SPLIT = 87,
|
||||
ANEURALNETWORKS_SQRT = 88,
|
||||
ANEURALNETWORKS_TILE = 89,
|
||||
ANEURALNETWORKS_TOPK_V2 = 90,
|
||||
|
Loading…
Reference in New Issue
Block a user