diff --git a/tensorflow/lite/g3doc/_book.yaml b/tensorflow/lite/g3doc/_book.yaml
index 22fd564635c..37a3682f473 100644
--- a/tensorflow/lite/g3doc/_book.yaml
+++ b/tensorflow/lite/g3doc/_book.yaml
@@ -139,7 +139,6 @@ upper_tabs:
         path: /lite/performance/measurement
       - title: "Delegates"
         path: /lite/performance/delegates
-        status: experimental
       - title: "GPU delegate"
         path: /lite/performance/gpu
       - title: "Advanced GPU"
@@ -152,6 +151,9 @@ upper_tabs:
       - title: "Core ML delegate"
         path: /lite/performance/coreml_delegate
         status: experimental
+      - title: "Implementing a delegate"
+        path: /lite/performance/implementing_delegate
+        status: experimental
       - heading: "Optimize a model"
       - title: "Overview"
diff --git a/tensorflow/lite/g3doc/performance/delegates.md b/tensorflow/lite/g3doc/performance/delegates.md
index 6b233075398..b17c9c35fec 100644
--- a/tensorflow/lite/g3doc/performance/delegates.md
+++ b/tensorflow/lite/g3doc/performance/delegates.md
@@ -1,31 +1,61 @@
-# TensorFlow Lite delegates
+# TensorFlow Lite Delegates
-Note: Delegate API is still experimental and is subject to change.
+## Introduction
-## What is a TensorFlow Lite delegate?
+**Delegates** enable hardware acceleration of TensorFlow Lite models by
+leveraging on-device accelerators such as the GPU and
+[Digital Signal Processor (DSP)](https://en.wikipedia.org/wiki/Digital_signal_processor).
-A TensorFlow Lite delegate is a way to delegate part or all of graph execution
-to another executor.
+By default, TensorFlow Lite utilizes CPU kernels that are optimized for the
+[ARM Neon](https://developer.arm.com/documentation/dht0002/a/Introducing-NEON/NEON-architecture-overview/NEON-instructions)
+instruction set. However, the CPU is a multi-purpose processor that isn't
+necessarily optimized for the heavy arithmetic typically found in Machine
+Learning models (for example, the matrix math involved in convolution and dense
+layers).
-## Why should I use delegates?
+On the other hand, most modern mobile phones contain chips that are better at
+handling these heavy operations. Utilizing them for neural network operations
+provides huge benefits in terms of latency and power efficiency. For example,
+GPUs can provide up to a
+[5x speedup](https://blog.tensorflow.org/2020/08/faster-mobile-gpu-inference-with-opencl.html)
+in latency, while the
+[Qualcomm® Hexagon DSP](https://developer.qualcomm.com/software/hexagon-dsp-sdk/dsp-processor)
+has been shown to reduce power consumption by up to 75% in our experiments.
-Running inference on compute-heavy machine learning models on mobile devices is
-resource demanding due to the devices' limited processing and power.
+Each of these accelerators has an associated API that enables custom
+computations, such as [OpenCL](https://www.khronos.org/opencl/) or
+[OpenGL ES](https://www.khronos.org/opengles/) for the mobile GPU and the
+[Qualcomm® Hexagon SDK](https://developer.qualcomm.com/software/hexagon-dsp-sdk)
+for the DSP. Typically, you would have to write a lot of custom code to run a
+neural network through these interfaces. Things get even more complicated when
+you consider that each accelerator has its pros & cons and cannot execute every
+operation in a neural network. TensorFlow Lite's Delegate API solves this
+problem by acting as a bridge between the TFLite runtime and these lower-level
+APIs.
-Instead of relying on the CPU, some devices have hardware accelerators, such as
-GPU or DSP, that allows for better performance and higher energy efficiency.
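+As a rough illustration of what this looks like from the application side, the
+sketch below attaches the GPU delegate to a TensorFlow Lite interpreter in C++.
+The header and function names follow the GPU delegate's C API
+(`TfLiteGpuDelegateV2Create` and friends), but option structs and signatures
+can differ between TensorFlow Lite versions, so treat this as a sketch rather
+than a drop-in snippet:
+
+```c++
+#include "tensorflow/lite/delegates/gpu/delegate.h"
+#include "tensorflow/lite/interpreter.h"
+
+// Attaches the GPU delegate to an already-built interpreter. Returns the
+// delegate so the caller can delete it after the interpreter is destroyed,
+// or nullptr if delegation failed and execution stays on the CPU.
+TfLiteDelegate* TryApplyGpuDelegate(tflite::Interpreter* interpreter) {
+  // Create the GPU delegate with default options.
+  TfLiteGpuDelegateOptionsV2 options = TfLiteGpuDelegateOptionsV2Default();
+  TfLiteDelegate* gpu_delegate = TfLiteGpuDelegateV2Create(&options);
+
+  // Hand the supported parts of the graph to the delegate; unsupported
+  // operations keep running on the CPU.
+  if (interpreter->ModifyGraphWithDelegate(gpu_delegate) != kTfLiteOk) {
+    TfLiteGpuDelegateV2Delete(gpu_delegate);
+    return nullptr;
+  }
+  return gpu_delegate;
+}
+```
+
+The sections below help you choose a suitable delegate for your model and
+platform, and validate that it actually helps.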
+![runtime with delegates](images/delegate_runtime.png)
-## Using the built-in delegates
+## Choosing a Delegate
-TensorFlow Lite provides the following delegates for hardware acceleration:
+TensorFlow Lite supports multiple delegates, each of which is optimized for
+certain platform(s) and particular types of models. Usually, there will be
+multiple delegates applicable to your use-case, depending on two major
+criteria: the *Platform* (Android or iOS?) you target, and the *Model-type*
+(floating-point or quantized?) that you are trying to accelerate.
+
+### Delegates by Platform
+
+#### Cross-platform (Android & iOS)
+
+* **GPU delegate** - The GPU delegate can be used on both Android and iOS. It
+  is optimized to run 32-bit and 16-bit float based models where a GPU is
+  available. It also supports 8-bit quantized models and runs them with GPU
+  performance on par with their float versions. For details on the GPU
+  delegate, see [TensorFlow Lite on GPU](gpu_advanced.md). For step-by-step
+  tutorials on using the GPU delegate with Android and iOS, see
+  [TensorFlow Lite GPU Delegate Tutorial](gpu.md).
+
+#### Android
-* **GPU delegate for cross platform acceleration** - The GPU delegate can be
-  used on both Android and iOS. It is optimized to run 32-bit and 16-bit float
-  based models where a GPU is available. It also supports 8-bit quantized
-  models and provides GPU performance on par with their float versions. For
-  details on the GPU delegate, see [TensorFlow Lite on GPU](gpu_advanced.md).
-  For step-by-step tutorials on using the GPU delegate with Android and iOS,
-  see [TensorFlow Lite GPU Delegate Tutorial](gpu.md).
 * **NNAPI delegate for newer Android devices** - The NNAPI delegate can be
   used to accelerate models on Android devices with GPU, DSP and / or NPU
   available. It is available in Android 8.1 (API 27+) or higher. For an
@@ -33,210 +63,188 @@ TensorFlow Lite provides the following delegates for hardware acceleration:
   practices, see [TensorFlow Lite NNAPI delegate](nnapi.md).
 * **Hexagon delegate for older Android devices** - The Hexagon delegate can be
   used to accelerate models on Android devices with Qualcomm Hexagon DSP. It
-  can be used on devices older version of Android OS that does not fully
-  support NNAPI. See [TensorFlow Lite Hexagon delegate](hexagon_delegate.md)
-  for more detail.
+  can be used on devices running older versions of Android that do not support
+  NNAPI. See [TensorFlow Lite Hexagon delegate](hexagon_delegate.md) for more
+  detail.
+
+#### iOS
+
 * **Core ML delegate for newer iPhones and iPads** - For newer iPhones and
   iPads where Neural Engine is available, you can use Core ML delegate to
-  accelerate inference for 32-bit float based models. Neural Engine is
-  available Apple mobile devices with A12 SoC or higher. For an overview of
-  the Core ML delegate and step-by-step instructions, see
+  accelerate inference for 32-bit or 16-bit floating-point models. Neural
+  Engine is available on Apple mobile devices with the A12 SoC or higher. For
+  an overview of the Core ML delegate and step-by-step instructions, see
   [TensorFlow Lite Core ML delegate](coreml_delegate.md).
-## How do delegates work?
+### Delegates by model type
-Let's say we have a simple model graph such as the following:
+Each accelerator is designed with a certain bit-width of data in mind. If you
+provide a floating-point model to a delegate that only supports 8-bit quantized
+operations (such as the [Hexagon delegate](hexagon_delegate.md)), it will reject
+all its operations and the model will run entirely on the CPU. To avoid such
+surprises, the table below provides an overview of delegate support based on
+model type:
-![Original graph](../images/performance/tflite_delegate_graph_1.png "Original Graph")
+**Model Type** | **GPU** | **NNAPI** | **Hexagon** | **CoreML**
+---------------------------------------------------------------------------------------------------------- | ------- | --------- | ----------- | ----------
+Floating-point (32 bit) | Yes | Yes | No | Yes
+[Post-training float16 quantization](post_training_float16_quant.ipynb) | Yes | No | No | Yes
+[Post-training dynamic range quantization](post_training_quant.ipynb) | Yes | Yes | No | No
+[Post-training integer quantization](post_training_integer_quant.ipynb) | Yes | Yes | Yes | No
+[Quantization-aware training](http://www.tensorflow.org/model_optimization/guide/quantization/training) | Yes | Yes | Yes | No
-If a delegate was provided for specific operations, then TensorFlow Lite will
-split the graph into multiple subgraphs where each subgraph will be handled by a
-delegate.
+### Validating performance
-Let's assume that a delegate, `MyDelegate`, has a faster implementation for
-Conv2D and Mean operations. The resulting main graph will be updated to look
-like below.
+The information in this section acts as a rough guideline for shortlisting the
+delegates that could improve your application. However, it is important to note
+that each delegate has a pre-defined set of operations it supports, and may
+perform differently depending on the model and device; for example, the
+[NNAPI delegate](nnapi.md) may choose to use Google's Edge-TPU on a Pixel phone
+while utilizing a DSP on another device. Therefore, it is usually recommended
+that you perform some benchmarking to gauge how useful a delegate is for your
+needs. This also helps justify the binary size increase associated with
+attaching a delegate to the TensorFlow Lite runtime.
-![Graph with delegate](../images/performance/tflite_delegate_graph_2.png "Graph with delegate")
+TensorFlow Lite has extensive performance and accuracy-evaluation tooling that
+can empower developers to be confident in using delegates in their application.
+These tools are discussed in the next section.
-Each subgraph that is handled by a delegate will be replaced with a node that
-evaluates the subgraph on its invoked call.
+## Tools for Evaluation
-Depending on the model, the final graph can end up with one node, which means
-that all of the graphs were delegated or multiple nodes handled the subgraphs.
-In general, you don't want to have multiple subgraphs handled by the delegate,
-since each time you switch from delegate to the main graph, there is an overhead
-for passing the results from the subgraph to the main graph. It's not always
-safe to share memory.
+### Latency & memory footprint
-## How to add a delegate
+TensorFlow Lite’s
+[benchmark tool](https://www.tensorflow.org/lite/performance/measurement) can be
+used with suitable parameters to estimate model performance, including average
+inference latency, initialization overhead, memory footprint, etc. This tool
+supports multiple flags to figure out the best delegate configuration for your
+model. For instance, `--gpu_backend=gl` can be specified with `--use_gpu` to
+measure GPU execution with OpenGL.
The complete list of supported delegate +parameters is defined in the +[detailed documentation](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/tools/delegates/README.md#tflite-delegate-registrar). -_Note that the API used below is experimental and is subject to change._ +Here’s an example run for a quantized model with GPU via `adb`: -Based on the previous section, to add a delegate, we need to do the following: - -1. Define a kernel node that is responsible for evaluating the delegate - subgraph. -1. Create an instance of - [TfLiteDelegate](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/c/common.h#L611), - which is responsible for registering the kernel node and claiming the nodes - that the delegate can execute. - -To see it in code, let's define a delegate and call it `MyDelegate`, which can -execute Conv2D and Mean operations faster. - -```c++ -#include "tensorflow/lite/util.h" -#include "tensorflow/lite/builtin_ops.h" -#include "tensorflow/lite/context_util.h" - -// This is where the execution of the operations or whole graph happens. -// The class below has an empty implementation just as a guideline -// on the structure. -class MyDelegate { - public: - // Returns true if my delegate can handle this type of op. - static bool SupportedOp(const TfLiteRegistration* registration) { - switch (registration->builtin_code) { - case kTfLiteBuiltinConv2d: - case kTfLiteBuiltinMean: - return true; - default: - return false; - } - } - - // Any initialization code needed - bool Init() {} - // Any preparation work needed (e.g. allocate buffers) - bool Prepare(TfLiteContext* context, TfLiteNode* node) {} - // Actual running of the delegate subgraph. - bool Invoke(TfLiteContext* context, TfLiteNode* node) {} - // ... Add any other methods needed. -}; - -// Create the TfLiteRegistration for the Kernel node which will replace -// the subgraph in the main TfLite graph. -TfLiteRegistration GetMyDelegateNodeRegistration() { - // This is the registration for the Delegate Node that gets added to - // the TFLite graph instead of the subgraph it replaces. - // It is treated as an OP node. But in our case - // Init will initialize the delegate. - // Invoke will run the delegate graph. - // Prepare for preparing the delegate. - // Free for any cleaning needed by the delegate. 
- TfLiteRegistration kernel_registration; - kernel_registration.builtin_code = kTfLiteBuiltinDelegate; - kernel_registration.custom_name = "MyDelegate"; - kernel_registration.free = [](TfLiteContext* context, void* buffer) -> void { - delete reinterpret_cast(buffer); - }; - kernel_registration.init = [](TfLiteContext* context, const char* buffer, - size_t) -> void* { - // In the node init phase, initialize MyDelegate instance - const TfLiteDelegateParams* delegate_params = - reinterpret_cast(buffer); - MyDelegate* my_delegate = new MyDelegate; - if (!my_delegate->Init(context, params)) { - return nullptr; - } - return my_delegate; - }; - kernel_registration.invoke = [](TfLiteContext* context, - TfLiteNode* node) -> TfLiteStatus { - MyDelegate* kernel = reinterpret_cast(node->user_data); - return kernel->Invoke(context, node); - }; - kernel_registration.prepare = [](TfLiteContext* context, - TfLiteNode* node) -> TfLiteStatus { - MyDelegate* kernel = reinterpret_cast(node->user_data); - return kernel->Prepare(context, node); - }; - - return kernel_registration; -} - -// TfLiteDelegate methods - -TfLiteStatus DelegatePrepare(TfLiteContext* context, TfLiteDelegate* delegate) { - // Claim all nodes that can be evaluated by the delegate and ask the - // framework to update the graph with delegate kernel instead. - std::vector supported_nodes; - TfLiteIntArray* plan; - TF_LITE_ENSURE_STATUS(context->GetExecutionPlan(context, &plan)); - TfLiteNode* node; - TfLiteRegistration* registration; - for (int node_index : TfLiteIntArrayView(plan)) { - TF_LITE_ENSURE_STATUS(context->GetNodeAndRegistration( - context, node_index, &node, ®istration)); - if (MyDelegate::SupportedOp(registration)) { - supported_nodes.push_back(node_index); - } - } - TfLiteRegistration my_delegate_kernel_registration = - GetMyDelegateNodeRegistration(); - - // This call split the graphs into subgraphs, for subgraphs that can be - // handled by the delegate, it will replace it with a - // 'my_delegate_kernel_registration' - TfLiteIntArray* supported_nodes_int_array = - ::tflite::ConvertVectorToTfLiteIntArray(supported_nodes); - auto status = context->ReplaceNodeSubsetsWithDelegateKernels( - context, my_delegate_kernel_registration, - supported_nodes_int_array, delegate); - TfLiteIntArrayFree(supported_nodes_int_array); - return status -} - -void FreeBufferHandle(TfLiteContext* context, TfLiteDelegate* delegate, - TfLiteBufferHandle* handle) { - // Do any cleanups. -} - -TfLiteStatus CopyToBufferHandle(TfLiteContext* context, - TfLiteDelegate* delegate, - TfLiteBufferHandle buffer_handle, - TfLiteTensor* tensor) { - // Copies data from tensor to delegate buffer if needed. - return kTfLiteOk; -} - -TfLiteStatus CopyFromBufferHandle(TfLiteContext* context, - TfLiteDelegate* delegate, - TfLiteBufferHandle buffer_handle, - TfLiteTensor* tensor) { - // Copies the data from delegate buffer into the tensor raw memory. - return kTfLiteOk; -} - -// Caller takes ownership of the returned pointer. -TfLiteDelegate* CreateMyDelegate() { - TfLiteDelegate* delegate = new TfLiteDelegate; - - delegate->data_ = nullptr; - delegate->flags = kTfLiteDelegateFlagsNone; - delegate->Prepare = &DelegatePrepare; - // This cannot be null. - delegate->CopyFromBufferHandle = &CopyFromBufferHandle; - // This can be null. - delegate->CopyToBufferHandle = &CopyToBufferHandle; - // This can be null. 
-  delegate->FreeBufferHandle = &FreeBufferHandle;
-
-  return delegate;
-}
-
-
-// To add the delegate you need to call
-
-auto* my_delegate = CreateMyDelegate();
-if (interpreter->ModifyGraphWithDelegate(my_delegate) !=
-    kTfLiteOk) {
-  // Handle error
-} else {
-  interpreter->Invoke();
-}
-...
-// Don't forget to delete your delegate
-delete my_delegate;
```
+adb shell /data/local/tmp/benchmark_model \
+  --graph=/data/local/tmp/mobilenet_v1_224_quant.tflite \
+  --use_gpu=true
```
+
+You can download a pre-built version of this tool for Android (64-bit ARM
+architecture)
+[here](https://storage.googleapis.com/tensorflow-nightly-public/prod/tensorflow/release/lite/tools/nightly/latest/android_aarch64_benchmark_model.apk)
+([more details](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/tools/benchmark/android)).
+
+### Accuracy & correctness
+
+Delegates usually perform computations at a different precision than their CPU
+counterparts. As a result, there is a (usually minor) accuracy tradeoff
+associated with utilizing a delegate for hardware acceleration. Note that this
+isn't *always* true; for example, since the GPU uses floating-point precision to
+run quantized models, there might be a slight precision improvement (e.g.,
+<1% Top-5 improvement in ILSVRC image classification).
+
+TensorFlow Lite has two types of tooling to measure how accurately a delegate
+behaves for a given model: *Task-Based* and *Task-Agnostic*. All the tools
+described in this section support the
+[advanced delegation parameters](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/tools/delegates/README.md#tflite-delegate-registrar)
+used by the benchmarking tool from the previous section. Note that the
+sub-sections below focus on *delegate evaluation* (Does the delegate perform the
+same as the CPU?) rather than model evaluation (Is the model itself good for the
+task?).
+
+#### Task-Based Evaluation
+
+TensorFlow Lite has tools to evaluate correctness on two image-based tasks:
+
+* [ILSVRC 2012](http://image-net.org/challenges/LSVRC/2012/) (Image
+  Classification) with
+  [top-K accuracy](https://en.wikipedia.org/wiki/Evaluation_measures_\(information_retrieval\)#Precision_at_K)
+
+* [COCO Object Detection (w/ bounding boxes)](https://cocodataset.org/#detection-2020)
+  with
+  [mean Average Precision (mAP)](https://en.wikipedia.org/wiki/Evaluation_measures_\(information_retrieval\)#Mean_average_precision)
+
+Prebuilt binaries of these tools (Android, 64-bit ARM architecture), along with
+documentation, can be found here:
+
+* [ImageNet Image Classification](https://storage.googleapis.com/tensorflow-nightly-public/prod/tensorflow/release/lite/tools/nightly/latest/android_aarch64_eval_imagenet_image_classification)
+  ([More details](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/tools/evaluation/tasks/imagenet_image_classification))
+* [COCO Object Detection](https://storage.googleapis.com/tensorflow-nightly-public/prod/tensorflow/release/lite/tools/nightly/latest/android_aarch64_eval_coco_object_detection)
+  ([More details](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/tools/evaluation/tasks/coco_object_detection))
+
+The example below demonstrates
+[image classification evaluation](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/tools/evaluation/tasks/imagenet_image_classification)
+with NNAPI utilizing Google's Edge-TPU on a Pixel 4:
+
+```
+adb shell /data/local/tmp/run_eval \
+  --model_file=/data/local/tmp/mobilenet_quant_v1_224.tflite \
+  --ground_truth_images_path=/data/local/tmp/ilsvrc_images \
+  --ground_truth_labels=/data/local/tmp/ilsvrc_validation_labels.txt \
+  --model_output_labels=/data/local/tmp/model_output_labels.txt \
+  --output_file_path=/data/local/tmp/accuracy_output.txt \
+  --use_nnapi=true \
+  --nnapi_accelerator_name=google-edgetpu \
+  --num_images=0  # Run on all images.
+```
+
+The expected output is a list of Top-K metrics from 1 to 10:
+
+```
+Top-1 Accuracy: 0.733333
+Top-2 Accuracy: 0.826667
+Top-3 Accuracy: 0.856667
+Top-4 Accuracy: 0.87
+Top-5 Accuracy: 0.89
+Top-6 Accuracy: 0.903333
+Top-7 Accuracy: 0.906667
+Top-8 Accuracy: 0.913333
+Top-9 Accuracy: 0.92
+Top-10 Accuracy: 0.923333
+```
+
+#### Task-Agnostic Evaluation
+
+For tasks where there isn't an established on-device evaluation tool, or if you
+are experimenting with custom models, TensorFlow Lite has the
+[Inference Diff](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/tools/evaluation/tasks/inference_diff)
+tool (Android, 64-bit ARM architecture binary
+[here](https://storage.googleapis.com/tensorflow-nightly-public/prod/tensorflow/release/lite/tools/nightly/latest/android_aarch64_eval_inference_diff)).
+
+Inference Diff compares TensorFlow Lite execution (in terms of latency &
+output-value deviation) in two settings:
+
+* Single-threaded CPU Inference
+* User-defined Inference - defined by
+  [these parameters](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/tools/delegates/README.md#tflite-delegate-registrar)
+
+To do so, the tool generates random Gaussian data and passes it through two
+TFLite Interpreters - one running single-threaded CPU kernels, and the other
+parametrized by the user's arguments.
+
+It measures the latency of both, as well as the absolute difference between the
+output tensors from each Interpreter, on a per-element basis.
+
+For a model with a single output tensor, the output might look like this:
+
+```
+Num evaluation runs: 50
+Reference run latency: avg=84364.2(us), std_dev=12525(us)
+Test run latency: avg=7281.64(us), std_dev=2089(us)
+OutputDiff[0]: avg_error=1.96277e-05, std_dev=6.95767e-06
+```
+
+What this means is that for the output tensor at index `0`, the elements from
+the CPU output differ from the delegate output by an average of `1.96e-05`.
+
+Note that interpreting these numbers requires deeper knowledge of the model and
+of what each output tensor signifies. If it's a simple regression that
+determines some sort of score or embedding, the difference should be low
+(otherwise it's an error with the delegate). However, outputs like the
+'detection class' from SSD models are a little harder to interpret. For
+example, this tool might show a difference, but that does not necessarily mean
+something is really wrong with the delegate: consider two (fake) classes,
+"TV (ID: 10)" and "Monitor (ID: 20)". If a delegate is slightly off the ground
+truth and predicts Monitor instead of TV, the output diff for this tensor can
+be as high as 20-10 = 10.
diff --git a/tensorflow/lite/g3doc/performance/images/delegate_runtime.png b/tensorflow/lite/g3doc/performance/images/delegate_runtime.png
new file mode 100644
index 00000000000..e229f0fda09
Binary files /dev/null and b/tensorflow/lite/g3doc/performance/images/delegate_runtime.png differ
diff --git a/tensorflow/lite/g3doc/performance/implementing_delegate.md b/tensorflow/lite/g3doc/performance/implementing_delegate.md
new file mode 100644
index 00000000000..85904cad091
--- /dev/null
+++ b/tensorflow/lite/g3doc/performance/implementing_delegate.md
@@ -0,0 +1,171 @@
+# Implementing a Delegate
+
+Note: The API used below is experimental and is subject to change.
+
+Follow the steps below to add a delegate:
+
+1.  Define a kernel node that is responsible for evaluating the delegate
+    subgraph.
+1.  Create an instance of
+    [TfLiteDelegate](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/c/common.h#L611),
+    which is responsible for registering the kernel node and claiming the nodes
+    that the delegate can execute.
+
+To see this in code, define a delegate, `MyDelegate`, that can execute Conv2D
+and Mean operations faster.
+
+```c++
+#include <vector>
+
+#include "tensorflow/lite/util.h"
+#include "tensorflow/lite/builtin_ops.h"
+#include "tensorflow/lite/context_util.h"
+
+// This is where the execution of the operations or whole graph happens.
+// The class below has an empty implementation just as a guideline
+// on the structure.
+class MyDelegate {
+ public:
+  // Returns true if MyDelegate can handle this type of op.
+  static bool SupportedOp(const TfLiteRegistration* registration) {
+    switch (registration->builtin_code) {
+      case kTfLiteBuiltinConv2d:
+      case kTfLiteBuiltinMean:
+        return true;
+      default:
+        return false;
+    }
+  }
+
+  // Any initialization code needed
+  bool Init(TfLiteContext* context, const TfLiteDelegateParams* params) {
+    return true;
+  }
+  // Any preparation work needed (e.g. allocate buffers)
+  TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+    return kTfLiteOk;
+  }
+  // Actual running of the delegate subgraph.
+  TfLiteStatus Invoke(TfLiteContext* context, TfLiteNode* node) {
+    return kTfLiteOk;
+  }
+  // ... Add any other methods needed.
+};
+
+// Create the TfLiteRegistration for the Kernel node which will replace
+// the subgraph in the main TfLite graph.
+TfLiteRegistration GetMyDelegateNodeRegistration() {
+  // This is the registration for the Delegate Node that gets added to
+  // the TFLite graph instead of the subgraph it replaces.
+  // It is treated as an OP node. But in this case:
+  // Init initializes the delegate.
+  // Invoke runs the delegate graph.
+  // Prepare prepares the delegate.
+  // Free performs any memory cleanup needed by the delegate.
+  TfLiteRegistration kernel_registration;
+  kernel_registration.builtin_code = kTfLiteBuiltinDelegate;
+  kernel_registration.custom_name = "MyDelegate";
+  kernel_registration.free = [](TfLiteContext* context, void* buffer) -> void {
+    delete reinterpret_cast<MyDelegate*>(buffer);
+  };
+  kernel_registration.init = [](TfLiteContext* context, const char* buffer,
+                                size_t) -> void* {
+    // In the node init phase, initialize the MyDelegate instance.
+    const TfLiteDelegateParams* delegate_params =
+        reinterpret_cast<const TfLiteDelegateParams*>(buffer);
+    MyDelegate* my_delegate = new MyDelegate;
+    if (!my_delegate->Init(context, delegate_params)) {
+      delete my_delegate;
+      return nullptr;
+    }
+    return my_delegate;
+  };
+  kernel_registration.invoke = [](TfLiteContext* context,
+                                  TfLiteNode* node) -> TfLiteStatus {
+    MyDelegate* kernel = reinterpret_cast<MyDelegate*>(node->user_data);
+    return kernel->Invoke(context, node);
+  };
+  kernel_registration.prepare = [](TfLiteContext* context,
+                                   TfLiteNode* node) -> TfLiteStatus {
+    MyDelegate* kernel = reinterpret_cast<MyDelegate*>(node->user_data);
+    return kernel->Prepare(context, node);
+  };
+
+  return kernel_registration;
+}
+
+// TfLiteDelegate methods
+
+TfLiteStatus DelegatePrepare(TfLiteContext* context, TfLiteDelegate* delegate) {
+  // Claim all nodes that can be evaluated by the delegate and ask the
+  // framework to update the graph with the delegate kernel instead.
+  std::vector<int> supported_nodes;
+  TfLiteIntArray* plan;
+  TF_LITE_ENSURE_STATUS(context->GetExecutionPlan(context, &plan));
+  TfLiteNode* node;
+  TfLiteRegistration* registration;
+  for (int node_index : TfLiteIntArrayView(plan)) {
+    TF_LITE_ENSURE_STATUS(context->GetNodeAndRegistration(
+        context, node_index, &node, &registration));
+    if (MyDelegate::SupportedOp(registration)) {
+      supported_nodes.push_back(node_index);
+    }
+  }
+  TfLiteRegistration my_delegate_kernel_registration =
+      GetMyDelegateNodeRegistration();
+
+  // This call splits the graph into subgraphs; each subgraph that can be
+  // handled by the delegate is replaced with a node that uses
+  // 'my_delegate_kernel_registration'.
+  TfLiteIntArray* supported_nodes_int_array =
+      ::tflite::ConvertVectorToTfLiteIntArray(supported_nodes);
+  auto status = context->ReplaceNodeSubsetsWithDelegateKernels(
+      context, my_delegate_kernel_registration,
+      supported_nodes_int_array, delegate);
+  TfLiteIntArrayFree(supported_nodes_int_array);
+  return status;
+}
+
+void FreeBufferHandle(TfLiteContext* context, TfLiteDelegate* delegate,
+                      TfLiteBufferHandle* handle) {
+  // Do any cleanups.
+}
+
+TfLiteStatus CopyToBufferHandle(TfLiteContext* context,
+                                TfLiteDelegate* delegate,
+                                TfLiteBufferHandle buffer_handle,
+                                TfLiteTensor* tensor) {
+  // Copies data from tensor to delegate buffer if needed.
+  return kTfLiteOk;
+}
+
+TfLiteStatus CopyFromBufferHandle(TfLiteContext* context,
+                                  TfLiteDelegate* delegate,
+                                  TfLiteBufferHandle buffer_handle,
+                                  TfLiteTensor* tensor) {
+  // Copies the data from delegate buffer into the tensor raw memory.
+  return kTfLiteOk;
+}
+
+// Caller takes ownership of the returned pointer.
+TfLiteDelegate* CreateMyDelegate() {
+  TfLiteDelegate* delegate = new TfLiteDelegate;
+
+  delegate->data_ = nullptr;
+  delegate->flags = kTfLiteDelegateFlagsNone;
+  // Prepare cannot be null.
+  delegate->Prepare = &DelegatePrepare;
+  // The buffer-handle callbacks below can be null if the delegate doesn't use
+  // buffer handles.
+  delegate->CopyFromBufferHandle = &CopyFromBufferHandle;
+  delegate->CopyToBufferHandle = &CopyToBufferHandle;
+  delegate->FreeBufferHandle = &FreeBufferHandle;
+
+  return delegate;
+}
+
+// To add the delegate you need to call
+
+auto* my_delegate = CreateMyDelegate();
+if (interpreter->ModifyGraphWithDelegate(my_delegate) !=
+    kTfLiteOk) {
+  // Handle error
+} else {
+  interpreter->Invoke();
+}
+...
+// Don't forget to delete your delegate
+delete my_delegate;
+```
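+
+The `MyDelegate` methods above are intentionally left empty. As a rough,
+hypothetical sketch of what the body of `Init` could look like, the snippet
+below walks the nodes handed over in `TfLiteDelegateParams` (a struct defined
+in `tensorflow/lite/c/common.h`); this is only an illustration of the data the
+runtime provides to the delegate kernel, not a required pattern:
+
+```c++
+// Hypothetical body for MyDelegate::Init: inspect the nodes this delegate
+// kernel has been asked to take over.
+bool Init(TfLiteContext* context, const TfLiteDelegateParams* params) {
+  for (int node_index : TfLiteIntArrayView(params->nodes_to_replace)) {
+    TfLiteNode* node;
+    TfLiteRegistration* registration;
+    if (context->GetNodeAndRegistration(context, node_index, &node,
+                                        &registration) != kTfLiteOk) {
+      return false;
+    }
+    // A real delegate would translate each node (e.g. Conv2D or Mean) into
+    // the accelerator's own representation here.
+  }
+  return true;
+}
+```
+
+`params->input_tensors` and `params->output_tensors` similarly describe the
+tensors that cross the boundary between the CPU graph and the delegate kernel.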