From a18d0fc601c0a2b83878cb307471edea0606ac7f Mon Sep 17 00:00:00 2001 From: Jared Duke Date: Wed, 3 Jul 2019 15:25:45 -0700 Subject: [PATCH] Add split support to the NNAPI delegate PiperOrigin-RevId: 256448757 --- .../lite/delegates/nnapi/nnapi_delegate.cc | 30 +++++++++++++++++++ tensorflow/lite/kernels/BUILD | 1 + tensorflow/lite/nnapi/NeuralNetworksTypes.h | 1 + 3 files changed, 32 insertions(+) diff --git a/tensorflow/lite/delegates/nnapi/nnapi_delegate.cc b/tensorflow/lite/delegates/nnapi/nnapi_delegate.cc index 0aa856adb8b..60644409402 100644 --- a/tensorflow/lite/delegates/nnapi/nnapi_delegate.cc +++ b/tensorflow/lite/delegates/nnapi/nnapi_delegate.cc @@ -1920,6 +1920,28 @@ class NNAPIDelegateKernel { }; } } break; + case kTfLiteBuiltinSplit: { + // Tensor indices: split_dim: 0, value: 1 + const TfLiteTensor& axis = context->tensors[node->inputs->data[0]]; + const TfLiteTensor& input = context->tensors[node->inputs->data[1]]; + if (version == 1 && android_sdk_version >= kMinSdkVersionForNNAPI12 && + (input.type == kTfLiteFloat32 || input.type == kTfLiteUInt8 || + input.type == kTfLiteInt32) && + (axis.type == kTfLiteInt32 && + axis.allocation_type == kTfLiteMmapRo)) { + return [](const NNAPIOpMappingArgs& mapping_args) + -> ANeuralNetworksOperationType { + const TfLiteTensor& axis = + mapping_args.context + ->tensors[mapping_args.node->inputs->data[0]]; + auto builtin = reinterpret_cast( + mapping_args.node->builtin_data); + mapping_args.builder->AddScalarInt32Operand(*axis.data.i32); + mapping_args.builder->AddScalarInt32Operand(builtin->num_splits); + return ANEURALNETWORKS_SPLIT; + }; + } + } break; default: // All other operators are not mapped. return nullptr; @@ -2335,6 +2357,14 @@ class NNAPIDelegateKernel { continue; } } + + if ((reg->builtin_code == kTfLiteBuiltinSplit) && + (input_index == node->inputs->data[0])) { + // Skip the axis input tensor; it will be added as a scalar operand + // by the Map() mapping. + continue; + } + // Pad and Padv2 have an optional parameter for a pad value which has // to be converted to a scalar type in NN API. if ((reg->builtin_code == kTfLiteBuiltinPadv2 || diff --git a/tensorflow/lite/kernels/BUILD b/tensorflow/lite/kernels/BUILD index 7fee4bf57d6..ee658656c0c 100644 --- a/tensorflow/lite/kernels/BUILD +++ b/tensorflow/lite/kernels/BUILD @@ -1231,6 +1231,7 @@ cc_test( name = "split_test", size = "small", srcs = ["split_test.cc"], + tags = ["tflite_nnapi"], deps = [ ":builtin_ops", ":test_main", diff --git a/tensorflow/lite/nnapi/NeuralNetworksTypes.h b/tensorflow/lite/nnapi/NeuralNetworksTypes.h index 2c641428cbc..882d40bf666 100644 --- a/tensorflow/lite/nnapi/NeuralNetworksTypes.h +++ b/tensorflow/lite/nnapi/NeuralNetworksTypes.h @@ -115,6 +115,7 @@ enum { ANEURALNETWORKS_RSQRT = 83, ANEURALNETWORKS_SELECT = 84, ANEURALNETWORKS_SIN = 85, + ANEURALNETWORKS_SPLIT = 87, ANEURALNETWORKS_SQRT = 88, ANEURALNETWORKS_TILE = 89, ANEURALNETWORKS_TOPK_V2 = 90,