Update delegates intro page with better details on support & tooling. Also add an 'implementing delegates' page that will be polished later with simpler instructions on delegate authoring.

PiperOrigin-RevId: 338165590
Change-Id: Ifd6e0fd261e157c41321f7dacc2fb16f40bd7769
This commit is contained in:
Sachin Joglekar 2020-10-20 16:59:31 -07:00 committed by TensorFlower Gardener
parent 2638bb9920
commit 19c8b34112
4 changed files with 393 additions and 212 deletions


@ -139,7 +139,6 @@ upper_tabs:
path: /lite/performance/measurement
- title: "Delegates"
path: /lite/performance/delegates
- title: "GPU delegate"
path: /lite/performance/gpu
- title: "Advanced GPU"
@ -152,6 +151,9 @@ upper_tabs:
- title: "Core ML delegate"
path: /lite/performance/coreml_delegate
status: experimental
- title: "Implementing a delegate"
path: /lite/performance/implementing_delegate
status: experimental
- heading: "Optimize a model"
- title: "Overview"


@ -1,31 +1,61 @@
# TensorFlow Lite Delegates
## What is a TensorFlow Lite delegate?
**Delegates** enable hardware acceleration of TensorFlow Lite models by
leveraging on-device accelerators such as the GPU and
[Digital Signal Processor (DSP)](https://en.wikipedia.org/wiki/Digital_signal_processor).
By default, TensorFlow Lite utilizes CPU kernels that are optimized for the
[ARM Neon](https://developer.arm.com/documentation/dht0002/a/Introducing-NEON/NEON-architecture-overview/NEON-instructions)
instruction set. However, the CPU is a multi-purpose processor that isn't
necessarily optimized for the heavy arithmetic typically found in Machine
Learning models (for example, the matrix math involved in convolution and dense
layers).
On the other hand, most modern mobile phones contain chips that are better at
handling these heavy operations. Utilizing them for neural network operations
provides huge benefits in terms of latency and power efficiency. For example,
GPUs can provide up to a
[5x speedup](https://blog.tensorflow.org/2020/08/faster-mobile-gpu-inference-with-opencl.html)
in latency, while the
[Qualcomm® Hexagon DSP](https://developer.qualcomm.com/software/hexagon-dsp-sdk/dsp-processor)
has been shown to reduce power consumption by up to 75% in our experiments.
Each of these accelerators has an associated API that enables custom
computations, such as [OpenCL](https://www.khronos.org/opencl/) or
[OpenGL ES](https://www.khronos.org/opengles/) for mobile GPU and the
[Qualcomm® Hexagon SDK](https://developer.qualcomm.com/software/hexagon-dsp-sdk)
for DSP. Typically, you would have to write a lot of custom code to run a neural
network through these interfaces. Things get even more complicated when you
consider that each accelerator has its pros & cons and cannot execute every
operation in a neural network. TensorFlow Lite's Delegate API solves this
problem by acting as a bridge between the TFLite runtime and these lower-level
APIs.
![runtime with delegates](images/delegate_runtime.png)
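From an application's point of view, using a delegate boils down to a single
call: `Interpreter::ModifyGraphWithDelegate`. The snippet below is a minimal
sketch (not taken from a shipped sample) of attaching the GPU delegate in C++;
the model path `model.tflite` is a placeholder and error handling is condensed.
```c++
#include <memory>

#include "tensorflow/lite/delegates/gpu/delegate.h"
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/model.h"

bool RunWithGpuDelegate() {
  // Load the model and build an interpreter as usual.
  auto model = tflite::FlatBufferModel::BuildFromFile("model.tflite");
  if (!model) return false;
  tflite::ops::builtin::BuiltinOpResolver resolver;
  std::unique_ptr<tflite::Interpreter> interpreter;
  tflite::InterpreterBuilder(*model, resolver)(&interpreter);
  if (!interpreter) return false;

  // Create the GPU delegate and let it claim the parts of the graph it
  // supports; operations it cannot handle stay on the CPU.
  TfLiteGpuDelegateOptionsV2 options = TfLiteGpuDelegateOptionsV2Default();
  TfLiteDelegate* gpu_delegate = TfLiteGpuDelegateV2Create(&options);
  bool ok = interpreter->ModifyGraphWithDelegate(gpu_delegate) == kTfLiteOk &&
            interpreter->AllocateTensors() == kTfLiteOk;
  if (ok) {
    // Fill input tensors here, then run inference as usual.
    ok = interpreter->Invoke() == kTfLiteOk;
  }

  // The delegate must outlive the interpreter; destroy the interpreter first.
  interpreter.reset();
  TfLiteGpuDelegateV2Delete(gpu_delegate);
  return ok;
}
```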
## Choosing a Delegate
TensorFlow Lite supports multiple delegates, each of which is optimized for
certain platform(s) and particular types of models. Usually, there will be
multiple delegates applicable to your use-case, depending on two major criteria:
the *Platform* (Android or iOS?) you target, and the *Model-type*
(floating-point or quantized?) that you are trying to accelerate.
### Delegates by Platform
#### Cross-platform (Android & iOS)
* **GPU delegate** - The GPU delegate can be used on both Android and iOS. It
is optimized to run 32-bit and 16-bit float based models where a GPU is
available. It also supports 8-bit quantized models and provides GPU
performance on par with their float versions. For details on the GPU
delegate, see [TensorFlow Lite on GPU](gpu_advanced.md). For step-by-step
tutorials on using the GPU delegate with Android and iOS, see
[TensorFlow Lite GPU Delegate Tutorial](gpu.md).
#### Android
* **NNAPI delegate for newer Android devices** - The NNAPI delegate can be
used to accelerate models on Android devices with GPU, DSP and / or NPU
available. It is available in Android 8.1 (API 27+) or higher. For an
@ -33,210 +63,188 @@ TensorFlow Lite provides the following delegates for hardware acceleration:
overview of the NNAPI delegate and step-by-step instructions on best
practices, see [TensorFlow Lite NNAPI delegate](nnapi.md). A minimal usage
sketch is shown after this list.
* **Hexagon delegate for older Android devices** - The Hexagon delegate can be
used to accelerate models on Android devices with Qualcomm Hexagon DSP. It
can be used on devices running older versions of Android that do not support
NNAPI. See [TensorFlow Lite Hexagon delegate](hexagon_delegate.md) for more
detail.
#### iOS
* **Core ML delegate for newer iPhones and iPads** - For newer iPhones and
iPads where the Neural Engine is available, you can use the Core ML delegate to
accelerate inference for 32-bit or 16-bit floating-point models. The Neural
Engine is available on Apple mobile devices with an A12 SoC or higher. For an
overview of the Core ML delegate and step-by-step instructions, see
[TensorFlow Lite Core ML delegate](coreml_delegate.md).
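As a concrete example of the Android options above, the following is a minimal
sketch (assuming an already-built `tflite::Interpreter` held in a
`std::unique_ptr` named `interpreter`) of attaching the NNAPI delegate; the
other delegates follow the same `ModifyGraphWithDelegate` pattern.
```c++
#include "tensorflow/lite/delegates/nnapi/nnapi_delegate.h"

// Sketch only: the delegate object must stay alive while the interpreter
// uses it.
tflite::StatefulNnApiDelegate::Options nnapi_options;
// Optionally pin execution to a specific accelerator reported by NNAPI,
// e.g. nnapi_options.accelerator_name = "google-edgetpu";
tflite::StatefulNnApiDelegate nnapi_delegate(nnapi_options);
if (interpreter->ModifyGraphWithDelegate(&nnapi_delegate) != kTfLiteOk) {
  // The delegate could not be applied; the model still runs on the CPU.
}
```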
### Delegates by model type
Each accelerator is designed with a certain bit-width of data in mind. If you
provide a floating-point model to a delegate that only supports 8-bit quantized
operations (such as the [Hexagon delegate](hexagon_delegate.md)), the delegate
will reject all of the model's operations, and the model will run entirely on
the CPU. To avoid such
surprises, the table below provides an overview of delegate support based on
model type:
**Model Type** | **GPU** | **NNAPI** | **Hexagon** | **CoreML**
------------------------------------------------------------------------------------------------------- | ------- | --------- | ----------- | ----------
Floating-point (32 bit) | Yes | Yes | No | Yes
[Post-training float16 quantization](post_training_float16_quant.ipynb) | Yes | No | No | Yes
[Post-training dynamic range quantization](post_training_quant.ipynb) | Yes | Yes | No | No
[Post-training integer quantization](post_training_integer_quant.ipynb) | Yes | Yes | Yes | No
[Quantization-aware training](http://www.tensorflow.org/model_optimization/guide/quantization/training) | Yes | Yes | Yes | No
### Validating performance
The information in this section acts as a rough guideline for shortlisting the
delegates that could improve your application. However, it is important to note
that each delegate has a pre-defined set of operations it supports, and may
perform differently depending on the model and device; for example, the
[NNAPI delegate](nnapi.md) may choose to use Google's Edge-TPU on a Pixel phone
while utilizing a DSP on another device. Therefore, it is usually recommended
that you perform some benchmarking to gauge how useful a delegate is for your
needs. This also helps justify the binary size increase associated with
attaching a delegate to the TensorFlow Lite runtime.
TensorFlow Lite has extensive performance and accuracy-evaluation tooling that
can empower developers to be confident in using delegates in their application.
These tools are discussed in the next section.
## Tools for Evaluation
### Latency & memory footprint
TensorFlow Lite's
[benchmark tool](https://www.tensorflow.org/lite/performance/measurement) can be
used with suitable parameters to estimate model performance, including average
inference latency, initialization overhead, memory footprint, etc. This tool
supports multiple flags to figure out the best delegate configuration for your
model. For instance, `--gpu_backend=gl` can be specified with `--use_gpu` to
measure GPU execution with OpenGL. The complete list of supported delegate
parameters is defined in the
[detailed documentation](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/tools/delegates/README.md#tflite-delegate-registrar).
Here's an example run for a quantized model with GPU via `adb`:
```
adb shell /data/local/tmp/benchmark_model \
  --graph=/data/local/tmp/mobilenet_v1_224_quant.tflite \
  --use_gpu=true
```
You can download a pre-built version of this tool for Android (64-bit ARM
architecture)
[here](https://storage.googleapis.com/tensorflow-nightly-public/prod/tensorflow/release/lite/tools/nightly/latest/android_aarch64_benchmark_model.apk)
([more details](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/tools/benchmark/android)).
### Accuracy & correctness
Delegates usually perform computations at a different precision than their CPU
counterparts. As a result, there is a (usually minor) accuracy tradeoff
associated with utilizing a delegate for hardware acceleration. Note that this
isn't *always* true; for example, since the GPU uses floating-point precision to
run quantized models, there might be a slight precision improvement (e.g.,
<1% Top-5 improvement in ILSVRC image classification).
TensorFlow Lite has two types of tooling to measure how accurately a delegate
behaves for a given model: *Task-Based* and *Task-Agnostic*. All the tools
described in this section support the
[advanced delegation parameters](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/tools/delegates/README.md#tflite-delegate-registrar)
used by the benchmarking tool from the previous section. Note that the
sub-sections below focus on *delegate evaluation* (Does the delegate perform the
same as the CPU?) rather than model evaluation (Is the model itself good for the
task?).
#### Task-Based Evaluation
TensorFlow Lite has tools to evaluate correctness on two image-based tasks:
* [ILSVRC 2012](http://image-net.org/challenges/LSVRC/2012/) (Image
Classification) with
[top-K accuracy](https://en.wikipedia.org/wiki/Evaluation_measures_\(information_retrieval\)#Precision_at_K)
* [COCO Object Detection (w/ bounding boxes)](https://cocodataset.org/#detection-2020)
with
[mean Average Precision (mAP)](https://en.wikipedia.org/wiki/Evaluation_measures_\(information_retrieval\)#Mean_average_precision)
Prebuilt binaries of these tools (Android, 64-bit ARM architecture), along with
documentation, can be found here:
* [ImageNet Image Classification](https://storage.googleapis.com/tensorflow-nightly-public/prod/tensorflow/release/lite/tools/nightly/latest/android_aarch64_eval_imagenet_image_classification)
([More details](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/tools/evaluation/tasks/imagenet_image_classification))
* [COCO Object Detection](https://storage.googleapis.com/tensorflow-nightly-public/prod/tensorflow/release/lite/tools/nightly/latest/android_aarch64_eval_coco_object_detection)
([More details](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/tools/evaluation/tasks/coco_object_detection))
The example below demonstrates
[image classification evaluation](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/tools/evaluation/tasks/imagenet_image_classification)
with NNAPI utilizing Google's Edge-TPU on a Pixel 4:
```
adb shell /data/local/tmp/run_eval \
  --model_file=/data/local/tmp/mobilenet_quant_v1_224.tflite \
  --ground_truth_images_path=/data/local/tmp/ilsvrc_images \
  --ground_truth_labels=/data/local/tmp/ilsvrc_validation_labels.txt \
  --model_output_labels=/data/local/tmp/model_output_labels.txt \
  --output_file_path=/data/local/tmp/accuracy_output.txt \
  --use_nnapi=true \
  --nnapi_accelerator_name=google-edgetpu \
  --num_images=0  # Run on all images.
```
The expected output is a list of Top-K metrics from 1 to 10:
```
Top-1 Accuracy: 0.733333
Top-2 Accuracy: 0.826667
Top-3 Accuracy: 0.856667
Top-4 Accuracy: 0.87
Top-5 Accuracy: 0.89
Top-6 Accuracy: 0.903333
Top-7 Accuracy: 0.906667
Top-8 Accuracy: 0.913333
Top-9 Accuracy: 0.92
Top-10 Accuracy: 0.923333
```
#### Task-Agnostic Evaluation
For tasks where there isn't an established on-device evaluation tool, or if you
are experimenting with custom models, TensorFlow Lite has the
[Inference Diff](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/tools/evaluation/tasks/inference_diff)
tool. (An Android, 64-bit ARM architecture binary is available
[here](https://storage.googleapis.com/tensorflow-nightly-public/prod/tensorflow/release/lite/tools/nightly/latest/android_aarch64_eval_inference_diff).)
Inference Diff compares TensorFlow Lite execution (in terms of latency &
output-value deviation) in two settings:
* Single-threaded CPU Inference
* User-defined Inference - defined by
[these parameters](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/tools/delegates/README.md#tflite-delegate-registrar)
To do so, the tool generates random Gaussian data and passes it through two
TFLite Interpreters - one running single-threaded CPU kernels, and the other
parametrized by the user's arguments.
It measures the latency of both, as well as the absolute difference between the
output tensors from each Interpreter, on a per-element basis.
For a model with a single output tensor, the output might look like this:
```
Num evaluation runs: 50
Reference run latency: avg=84364.2(us), std_dev=12525(us)
Test run latency: avg=7281.64(us), std_dev=2089(us)
OutputDiff[0]: avg_error=1.96277e-05, std_dev=6.95767e-06
```
What this means is that for the output tensor at index `0`, the elements from
the CPU output differ from the delegate output by an average of `1.96e-05`.
Note that interpreting these numbers requires deeper knowledge of the model and
what each output tensor signifies. If it's a simple regression that determines
some sort of score or embedding, the difference should be low (otherwise it's an
error with the delegate). However, outputs like the 'detection class' one from
SSD models are a little harder to interpret. For example, the tool might show a
difference for this output, but that may not mean something is really wrong with
the delegate: consider two (fake) classes, "TV (ID: 10)" and "Monitor (ID: 20)".
If a delegate is slightly off the golden truth and shows Monitor instead of TV,
the output diff for this tensor might be as high as 20 - 10 = 10.

New binary image file added (34 KiB), not shown.


@ -0,0 +1,171 @@
# Implementing a Delegate
Note: The API used below is experimental and is subject to change.
Follow the steps below to add a delegate:
1. Define a kernel node that is responsible for evaluating the delegate
subgraph.
1. Create an instance of
[TfLiteDelegate](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/c/common.h#L611),
which is responsible for registering the kernel node and claiming the nodes
that the delegate can execute.
To see this in code, let's define a delegate called `MyDelegate` that can
execute Conv2D and Mean ops faster.
```c++
#include "tensorflow/lite/util.h"
#include "tensorflow/lite/builtin_ops.h"
#include "tensorflow/lite/context_util.h"
// This is where the execution of the operations or whole graph happens.
// The class below has an empty implementation just as a guideline
// on the structure.
class MyDelegate {
public:
// Returns true if MyDelegate can handle this type of op.
static bool SupportedOp(const TfLiteRegistration* registration) {
switch (registration->builtin_code) {
case kTfLiteBuiltinConv2d:
case kTfLiteBuiltinMean:
return true;
default:
return false;
}
}
  // Any initialization code needed; `params` describes the graph partition
  // handed to this delegate.
  bool Init(TfLiteContext* context, const TfLiteDelegateParams* params) {
    return true;
  }
  // Any preparation work needed (e.g. allocate buffers).
  bool Prepare(TfLiteContext* context, TfLiteNode* node) { return true; }
  // Actual running of the delegate subgraph.
  bool Invoke(TfLiteContext* context, TfLiteNode* node) { return true; }
// ... Add any other methods needed.
};
// Create the TfLiteRegistration for the Kernel node which will replace
// the subgraph in the main TfLite graph.
TfLiteRegistration GetMyDelegateNodeRegistration() {
// This is the registration for the Delegate Node that gets added to
// the TFLite graph instead of the subgraph it replaces.
// It is treated as an OP node. But in this case
// Init initializes the delegate.
// Invoke runs the delegate graph.
// Prepare prepares the delegate.
// Free performs any memory cleanup needed by the delegate.
TfLiteRegistration kernel_registration;
kernel_registration.builtin_code = kTfLiteBuiltinDelegate;
kernel_registration.custom_name = "MyDelegate";
kernel_registration.free = [](TfLiteContext* context, void* buffer) -> void {
delete reinterpret_cast<MyDelegate*>(buffer);
};
kernel_registration.init = [](TfLiteContext* context, const char* buffer,
size_t) -> void* {
    // In the node init phase, initialize the MyDelegate instance.
    const TfLiteDelegateParams* delegate_params =
        reinterpret_cast<const TfLiteDelegateParams*>(buffer);
    MyDelegate* my_delegate = new MyDelegate;
    if (!my_delegate->Init(context, delegate_params)) {
      delete my_delegate;
      return nullptr;
    }
return my_delegate;
};
  kernel_registration.invoke = [](TfLiteContext* context,
                                  TfLiteNode* node) -> TfLiteStatus {
    MyDelegate* kernel = reinterpret_cast<MyDelegate*>(node->user_data);
    return kernel->Invoke(context, node) ? kTfLiteOk : kTfLiteError;
  };
  kernel_registration.prepare = [](TfLiteContext* context,
                                   TfLiteNode* node) -> TfLiteStatus {
    MyDelegate* kernel = reinterpret_cast<MyDelegate*>(node->user_data);
    return kernel->Prepare(context, node) ? kTfLiteOk : kTfLiteError;
  };
return kernel_registration;
}
// TfLiteDelegate methods
TfLiteStatus DelegatePrepare(TfLiteContext* context, TfLiteDelegate* delegate) {
  // Claim all nodes that can be evaluated by the delegate and ask the
  // framework to update the graph with the delegate kernel instead.
std::vector<int> supported_nodes;
TfLiteIntArray* plan;
TF_LITE_ENSURE_STATUS(context->GetExecutionPlan(context, &plan));
TfLiteNode* node;
TfLiteRegistration* registration;
for (int node_index : TfLiteIntArrayView(plan)) {
TF_LITE_ENSURE_STATUS(context->GetNodeAndRegistration(
context, node_index, &node, &registration));
if (MyDelegate::SupportedOp(registration)) {
supported_nodes.push_back(node_index);
}
}
TfLiteRegistration my_delegate_kernel_registration =
GetMyDelegateNodeRegistration();
  // This call splits the graph into subgraphs; each subgraph that can be
  // handled by the delegate is replaced with a single node that uses
  // 'my_delegate_kernel_registration'.
TfLiteIntArray* supported_nodes_int_array =
::tflite::ConvertVectorToTfLiteIntArray(supported_nodes);
auto status = context->ReplaceNodeSubsetsWithDelegateKernels(
context, my_delegate_kernel_registration,
supported_nodes_int_array, delegate);
TfLiteIntArrayFree(supported_nodes_int_array);
  return status;
}
void FreeBufferHandle(TfLiteContext* context, TfLiteDelegate* delegate,
TfLiteBufferHandle* handle) {
// Do any cleanups.
}
TfLiteStatus CopyToBufferHandle(TfLiteContext* context,
TfLiteDelegate* delegate,
TfLiteBufferHandle buffer_handle,
TfLiteTensor* tensor) {
// Copies data from tensor to delegate buffer if needed.
return kTfLiteOk;
}
TfLiteStatus CopyFromBufferHandle(TfLiteContext* context,
TfLiteDelegate* delegate,
TfLiteBufferHandle buffer_handle,
TfLiteTensor* tensor) {
// Copies the data from delegate buffer into the tensor raw memory.
return kTfLiteOk;
}
// Caller takes ownership of the returned pointer.
TfLiteDelegate* CreateMyDelegate() {
TfLiteDelegate* delegate = new TfLiteDelegate;
delegate->data_ = nullptr;
delegate->flags = kTfLiteDelegateFlagsNone;
delegate->Prepare = &DelegatePrepare;
// This cannot be null.
delegate->CopyFromBufferHandle = &CopyFromBufferHandle;
// This can be null.
delegate->CopyToBufferHandle = &CopyToBufferHandle;
// This can be null.
delegate->FreeBufferHandle = &FreeBufferHandle;
return delegate;
}
// To add the delegate you need to call
auto* my_delegate = CreateMyDelegate();
if (interpreter->ModifyGraphWithDelegate(my_delegate) !=
kTfLiteOk) {
// Handle error
} else {
interpreter->Invoke();
}
...
// Don't forget to delete your delegate
delete my_delegate;
```
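One detail the snippet above glosses over: `ModifyGraphWithDelegate` does not
take ownership of the delegate, and the interpreter keeps using it until the
interpreter itself is destroyed. The sketch below (reusing the
`CreateMyDelegate()` factory from the example, with `interpreter` assumed to be
a `std::unique_ptr<tflite::Interpreter>`) shows one way to manage the lifetimes:
```c++
#include <memory>

// Keep the delegate alive for at least as long as the interpreter that uses it.
std::unique_ptr<TfLiteDelegate> my_delegate(CreateMyDelegate());
if (interpreter->ModifyGraphWithDelegate(my_delegate.get()) != kTfLiteOk) {
  // Handle error.
} else {
  interpreter->Invoke();
}
// ...
// Destroy the interpreter first; only then is it safe to free the delegate.
interpreter.reset();
my_delegate.reset();
```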