1. Adapt flex delegate to the newly created SimpleDelegate APIs.
2. Change SimpleDelegate APIs accordingly as needed to adapt the flex delegate. a. Create a `version` virtual function. b. Add a new function acting as an initialization for TfLiteDelegate Prepare. c. Return a std::unique_ptr TfLiteDelegate for the factory method of TfLiteDelegateFactory for less error-prone memory management. PiperOrigin-RevId: 314225150 Change-Id: I6bf7ab48f1e72f390b49dd5c6ac7ecd11d29327a
This commit is contained in:
parent
24580ebdd9
commit
edeae9fb69
|
@ -744,8 +744,9 @@ typedef struct TfLiteDelegate {
|
|||
struct TfLiteDelegate* delegate);
|
||||
|
||||
// Copy the data from delegate buffer handle into raw memory of the given
|
||||
// 'tensor'. This cannot be null. The delegate is allowed to allocate the raw
|
||||
// bytes as long as it follows the rules for kTfLiteDynamic tensors.
|
||||
// 'tensor'. Note that the delegate is allowed to allocate the raw bytes as
|
||||
// long as it follows the rules for kTfLiteDynamic tensors, in which case this
|
||||
// cannot be null.
|
||||
TfLiteStatus (*CopyFromBufferHandle)(TfLiteContext* context,
|
||||
struct TfLiteDelegate* delegate,
|
||||
TfLiteBufferHandle buffer_handle,
|
||||
|
|
|
@ -61,6 +61,7 @@ cc_library(
|
|||
deps = [
|
||||
":delegate_data",
|
||||
":delegate_only_runtime",
|
||||
"//tensorflow/lite/delegates/utils:simple_delegate",
|
||||
] + select({
|
||||
"//tensorflow:android": [
|
||||
"//tensorflow/core:portable_tensorflow_lib",
|
||||
|
@ -82,6 +83,8 @@ cc_library(
|
|||
name = "delegate_only_runtime",
|
||||
srcs = [
|
||||
"delegate.cc",
|
||||
"kernel.cc",
|
||||
"kernel.h",
|
||||
],
|
||||
hdrs = [
|
||||
"delegate.h",
|
||||
|
@ -90,14 +93,18 @@ cc_library(
|
|||
deps = [
|
||||
":buffer_map",
|
||||
":delegate_data",
|
||||
":kernel",
|
||||
":util",
|
||||
"@flatbuffers",
|
||||
"@com_google_absl//absl/strings:strings",
|
||||
"//tensorflow/lite/core/api",
|
||||
"//tensorflow/lite/c:common",
|
||||
"//tensorflow/lite:kernel_api",
|
||||
"//tensorflow/lite:minimal_logging",
|
||||
"//tensorflow/lite:string",
|
||||
"//tensorflow/lite:string_util",
|
||||
"//tensorflow/lite:util",
|
||||
"//tensorflow/lite/delegates/utils:simple_delegate",
|
||||
"//tensorflow/lite/kernels:kernel_util",
|
||||
] + select({
|
||||
"//tensorflow:android": [
|
||||
"//tensorflow/core:portable_tensorflow_lib_lite",
|
||||
|
@ -106,7 +113,12 @@ cc_library(
|
|||
"//tensorflow/core:portable_tensorflow_lib_lite",
|
||||
],
|
||||
"//conditions:default": [
|
||||
"//tensorflow/core/common_runtime/eager:context",
|
||||
"//tensorflow/core/common_runtime/eager:execute",
|
||||
"//tensorflow/core/common_runtime/eager:tensor_handle",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core:framework",
|
||||
],
|
||||
}),
|
||||
alwayslink = 1,
|
||||
|
@ -163,40 +175,6 @@ tf_cc_test(
|
|||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "kernel",
|
||||
srcs = ["kernel.cc"],
|
||||
hdrs = ["kernel.h"],
|
||||
deps = [
|
||||
":delegate_data",
|
||||
":util",
|
||||
"@flatbuffers",
|
||||
"//tensorflow/lite/core/api",
|
||||
"//tensorflow/lite/c:common",
|
||||
"//tensorflow/lite:kernel_api",
|
||||
"//tensorflow/lite:string",
|
||||
"//tensorflow/lite/kernels:kernel_util",
|
||||
] + select({
|
||||
# TODO(b/111881878): The android_tensorflow_lib target pulls in the full
|
||||
# set of core TensorFlow kernels. We may want to revisit this dependency
|
||||
# to allow selective registration via build targets.
|
||||
"//tensorflow:android": [
|
||||
"//tensorflow/core:portable_tensorflow_lib_lite",
|
||||
],
|
||||
"//tensorflow:ios": [
|
||||
"//tensorflow/core:portable_tensorflow_lib_lite",
|
||||
],
|
||||
"//conditions:default": [
|
||||
"//tensorflow/core/common_runtime/eager:context",
|
||||
"//tensorflow/core/common_runtime/eager:execute",
|
||||
"//tensorflow/core/common_runtime/eager:tensor_handle",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core:framework",
|
||||
],
|
||||
}),
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "kernel_test",
|
||||
size = "small",
|
||||
|
@ -204,20 +182,10 @@ tf_cc_test(
|
|||
tags = ["no_gpu"], # GPU + flex is not officially supported.
|
||||
deps = [
|
||||
":delegate_data",
|
||||
":kernel",
|
||||
":delegate_only_runtime",
|
||||
":test_util",
|
||||
"@com_google_googletest//:gtest",
|
||||
] + select({
|
||||
"//tensorflow:android": [
|
||||
"//tensorflow/core:portable_tensorflow_lib",
|
||||
],
|
||||
"//tensorflow:ios": [
|
||||
"//tensorflow/core:portable_tensorflow_lib",
|
||||
],
|
||||
"//conditions:default": [
|
||||
"//tensorflow/core:tensorflow",
|
||||
],
|
||||
}),
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
|
|
|
@ -14,6 +14,7 @@ limitations under the License.
|
|||
==============================================================================*/
|
||||
#include "tensorflow/lite/delegates/flex/delegate.h"
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/strings/str_cat.h"
|
||||
|
@ -27,10 +28,32 @@ limitations under the License.
|
|||
#include "tensorflow/lite/util.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace flex {
|
||||
namespace delegate {
|
||||
|
||||
TfLiteStatus Prepare(TfLiteContext* context, TfLiteDelegate* delegate) {
|
||||
// Corresponding weak declaration found in lite/interpreter_builder.cc.
|
||||
TfLiteDelegateUniquePtr AcquireFlexDelegate() {
|
||||
return tflite::FlexDelegate::Create();
|
||||
}
|
||||
|
||||
TfLiteDelegateUniquePtr FlexDelegate::Create(
|
||||
std::unique_ptr<FlexDelegate> base_delegate) {
|
||||
TFLITE_LOG_PROD_ONCE(TFLITE_LOG_INFO,
|
||||
"Created TensorFlow Lite delegate for select TF ops.");
|
||||
if (base_delegate == nullptr) {
|
||||
base_delegate.reset(new FlexDelegate());
|
||||
}
|
||||
auto flex_delegate = TfLiteDelegateFactory::Create(std::move(base_delegate));
|
||||
flex_delegate->CopyFromBufferHandle =
|
||||
[](TfLiteContext* context, TfLiteDelegate* delegate,
|
||||
TfLiteBufferHandle buffer_handle,
|
||||
TfLiteTensor* tensor) -> TfLiteStatus {
|
||||
return reinterpret_cast<FlexDelegate*>(delegate->data_)
|
||||
->CopyFromBufferHandle(context, buffer_handle, tensor);
|
||||
};
|
||||
flex_delegate->flags |= kTfLiteDelegateFlagsAllowDynamicTensors;
|
||||
return flex_delegate;
|
||||
}
|
||||
|
||||
TfLiteStatus FlexDelegate::Initialize(TfLiteContext* context) {
|
||||
// If the TensorFlow Lite thread count is explicitly configured, use it,
|
||||
// otherwise rely on the default TensorFlow threading behavior.
|
||||
tensorflow::SessionOptions session_options;
|
||||
|
@ -39,47 +62,37 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteDelegate* delegate) {
|
|||
context->recommended_num_threads);
|
||||
}
|
||||
|
||||
auto status = reinterpret_cast<DelegateData*>(delegate->data_)
|
||||
->Prepare(session_options);
|
||||
auto status = delegate_data_.Prepare(session_options);
|
||||
if (!status.ok()) {
|
||||
context->ReportError(context, "Failed to initialize TensorFlow context: %s",
|
||||
status.error_message().c_str());
|
||||
return kTfLiteError;
|
||||
}
|
||||
|
||||
// Get the nodes in the current execution plan. Interpreter owns this array.
|
||||
TfLiteIntArray* plan;
|
||||
TF_LITE_ENSURE_STATUS(context->GetExecutionPlan(context, &plan));
|
||||
|
||||
// Add all custom ops starting with "Flex" to list of supported nodes.
|
||||
std::vector<int> supported_nodes;
|
||||
for (int node_index : TfLiteIntArrayView(plan)) {
|
||||
TfLiteNode* node;
|
||||
TfLiteRegistration* registration;
|
||||
TF_LITE_ENSURE_STATUS(context->GetNodeAndRegistration(
|
||||
context, node_index, &node, &registration));
|
||||
|
||||
if (IsFlexOp(registration->custom_name)) {
|
||||
supported_nodes.push_back(node_index);
|
||||
}
|
||||
}
|
||||
|
||||
// Request TFLite to partition the graph and make kernels for each independent
|
||||
// node sub set.
|
||||
TfLiteIntArray* size_and_nodes =
|
||||
ConvertVectorToTfLiteIntArray(supported_nodes);
|
||||
context->ReplaceNodeSubsetsWithDelegateKernels(context, GetKernel(),
|
||||
size_and_nodes, delegate);
|
||||
TfLiteIntArrayFree(size_and_nodes);
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
TfLiteStatus CopyFromBufferHandle(TfLiteContext* context,
|
||||
TfLiteDelegate* delegate,
|
||||
TfLiteBufferHandle buffer_handle,
|
||||
TfLiteTensor* output) {
|
||||
BufferMap* buffer_map =
|
||||
reinterpret_cast<DelegateData*>(delegate->data_)->GetBufferMap(context);
|
||||
const char* FlexDelegate::name() const {
|
||||
static constexpr char kName[] = "TfLiteFlexDelegate";
|
||||
return kName;
|
||||
}
|
||||
|
||||
bool FlexDelegate::IsNodeSupportedByDelegate(
|
||||
const TfLiteRegistration* registration, const TfLiteNode* node,
|
||||
TfLiteContext* context) const {
|
||||
return IsFlexOp(registration->custom_name);
|
||||
}
|
||||
|
||||
std::unique_ptr<SimpleDelegateKernelInterface>
|
||||
FlexDelegate::CreateDelegateKernelInterface() {
|
||||
return std::unique_ptr<SimpleDelegateKernelInterface>(
|
||||
new tflite::flex::DelegateKernel());
|
||||
}
|
||||
|
||||
TfLiteStatus FlexDelegate::CopyFromBufferHandle(
|
||||
TfLiteContext* context, TfLiteBufferHandle buffer_handle,
|
||||
TfLiteTensor* output) {
|
||||
flex::BufferMap* buffer_map = delegate_data_.GetBufferMap(context);
|
||||
|
||||
if (!buffer_map->HasTensor(buffer_handle)) {
|
||||
context->ReportError(context, "Invalid tensor index %d.", buffer_handle);
|
||||
|
@ -122,31 +135,4 @@ TfLiteStatus CopyFromBufferHandle(TfLiteContext* context,
|
|||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
} // namespace delegate
|
||||
} // namespace flex
|
||||
|
||||
// Corresponding weak declaration found in lite/model.cc.
|
||||
std::unique_ptr<TfLiteDelegate, void (*)(TfLiteDelegate*)>
|
||||
AcquireFlexDelegate() {
|
||||
return std::unique_ptr<TfLiteDelegate, void (*)(TfLiteDelegate*)>(
|
||||
tflite::FlexDelegate::Create().release(), [](TfLiteDelegate* delegate) {
|
||||
delete reinterpret_cast<tflite::FlexDelegate*>(delegate);
|
||||
});
|
||||
}
|
||||
|
||||
std::unique_ptr<FlexDelegate> FlexDelegate::Create() {
|
||||
TFLITE_LOG_PROD_ONCE(TFLITE_LOG_INFO,
|
||||
"Created TensorFlow Lite delegate for select TF ops.");
|
||||
return std::unique_ptr<FlexDelegate>(new FlexDelegate());
|
||||
}
|
||||
|
||||
FlexDelegate::FlexDelegate() : TfLiteDelegate(TfLiteDelegateCreate()) {
|
||||
data_ = &delegate_data_;
|
||||
Prepare = &flex::delegate::Prepare;
|
||||
CopyFromBufferHandle = &flex::delegate::CopyFromBufferHandle;
|
||||
flags = kTfLiteDelegateFlagsAllowDynamicTensors;
|
||||
}
|
||||
|
||||
FlexDelegate::~FlexDelegate() {}
|
||||
|
||||
} // namespace tflite
|
||||
|
|
|
@ -17,9 +17,16 @@ limitations under the License.
|
|||
|
||||
#include "tensorflow/lite/c/common.h"
|
||||
#include "tensorflow/lite/delegates/flex/delegate_data.h"
|
||||
#include "tensorflow/lite/delegates/utils/simple_delegate.h"
|
||||
|
||||
namespace tflite {
|
||||
|
||||
namespace flex {
|
||||
namespace testing {
|
||||
class KernelTest;
|
||||
} // namespace testing
|
||||
} // namespace flex
|
||||
|
||||
// WARNING: This is an experimental interface that is subject to change.
|
||||
// Delegate that can be used to extract parts of a graph that are designed to be
|
||||
// executed by TensorFlow's runtime via Eager.
|
||||
|
@ -33,22 +40,49 @@ namespace tflite {
|
|||
// ... build interpreter ...
|
||||
//
|
||||
// if (delegate) {
|
||||
// interpreter->ModifyGraphWithDelegate(delegate.get());
|
||||
// interpreter->ModifyGraphWithDelegate(std::move(delegate));
|
||||
// }
|
||||
// ... run inference ...
|
||||
// ... destroy interpreter ...
|
||||
// ... destroy delegate ...
|
||||
class FlexDelegate : public TfLiteDelegate {
|
||||
class FlexDelegate : public SimpleDelegateInterface {
|
||||
public:
|
||||
friend class flex::testing::KernelTest;
|
||||
|
||||
// Creates a delegate that supports TF ops.
|
||||
//
|
||||
// If the underlying TF Flex context creation fails, returns null.
|
||||
static std::unique_ptr<FlexDelegate> Create();
|
||||
static TfLiteDelegateUniquePtr Create() {
|
||||
return Create(/*base_delegate*/ nullptr);
|
||||
}
|
||||
|
||||
~FlexDelegate();
|
||||
~FlexDelegate() override {}
|
||||
|
||||
private:
|
||||
FlexDelegate();
|
||||
flex::DelegateData* mutable_data() { return &delegate_data_; }
|
||||
|
||||
protected:
|
||||
// We sometimes have to create certain stub data to test FlexDelegate. To
|
||||
// achieve this, we will make a testing flex delegate class that inherits from
|
||||
// FlexDelegate to override certain things for stub data creation. Therefore,
|
||||
// this function accepts a FlexDelegate instance to initialize it properly for
|
||||
// creating a testing flex delegate in some cases, and it is only used in
|
||||
// testing.
|
||||
static TfLiteDelegateUniquePtr Create(
|
||||
std::unique_ptr<FlexDelegate> base_delegate);
|
||||
|
||||
FlexDelegate() {}
|
||||
|
||||
const char* name() const override;
|
||||
|
||||
bool IsNodeSupportedByDelegate(const TfLiteRegistration* registration,
|
||||
const TfLiteNode* node,
|
||||
TfLiteContext* context) const override;
|
||||
|
||||
TfLiteStatus Initialize(TfLiteContext* context) override;
|
||||
|
||||
std::unique_ptr<SimpleDelegateKernelInterface> CreateDelegateKernelInterface()
|
||||
override;
|
||||
|
||||
TfLiteStatus CopyFromBufferHandle(TfLiteContext* context,
|
||||
TfLiteBufferHandle buffer_handle,
|
||||
TfLiteTensor* output);
|
||||
|
||||
flex::DelegateData delegate_data_;
|
||||
};
|
||||
|
|
|
@ -26,8 +26,7 @@ using ::testing::ElementsAre;
|
|||
|
||||
class DelegateTest : public testing::FlexModelTest {
|
||||
public:
|
||||
DelegateTest() {
|
||||
delegate_ = FlexDelegate::Create();
|
||||
DelegateTest() : delegate_(FlexDelegate::Create()) {
|
||||
interpreter_.reset(new Interpreter(&error_reporter_));
|
||||
}
|
||||
|
||||
|
@ -44,7 +43,7 @@ class DelegateTest : public testing::FlexModelTest {
|
|||
}
|
||||
|
||||
private:
|
||||
std::unique_ptr<FlexDelegate> delegate_;
|
||||
std::unique_ptr<TfLiteDelegate, void (*)(TfLiteDelegate*)> delegate_;
|
||||
};
|
||||
|
||||
TEST_F(DelegateTest, FullGraph) {
|
||||
|
|
|
@ -18,6 +18,7 @@ cc_library(
|
|||
],
|
||||
deps = [
|
||||
"//tensorflow/lite/delegates/flex:delegate",
|
||||
"//tensorflow/lite/delegates/utils:simple_delegate",
|
||||
"//tensorflow/lite/java/jni",
|
||||
"//tensorflow/lite/testing:init_tensorflow",
|
||||
],
|
||||
|
|
|
@ -16,6 +16,7 @@ limitations under the License.
|
|||
#include <jni.h>
|
||||
|
||||
#include "tensorflow/lite/delegates/flex/delegate.h"
|
||||
#include "tensorflow/lite/delegates/utils/simple_delegate.h"
|
||||
#include "tensorflow/lite/testing/init_tensorflow.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
|
@ -37,7 +38,8 @@ Java_org_tensorflow_lite_flex_FlexDelegate_nativeCreateDelegate(JNIEnv* env,
|
|||
JNIEXPORT void JNICALL
|
||||
Java_org_tensorflow_lite_flex_FlexDelegate_nativeDeleteDelegate(
|
||||
JNIEnv* env, jclass clazz, jlong delegate) {
|
||||
delete reinterpret_cast<tflite::FlexDelegate*>(delegate);
|
||||
tflite::TfLiteDelegateFactory::DeleteSimpleDelegate(
|
||||
reinterpret_cast<struct TfLiteDelegate*>(delegate));
|
||||
}
|
||||
|
||||
#ifdef __cplusplus
|
||||
|
|
|
@ -25,6 +25,7 @@ limitations under the License.
|
|||
#include "tensorflow/lite/c/common.h"
|
||||
#include "tensorflow/lite/context_util.h"
|
||||
#include "tensorflow/lite/core/api/profiler.h"
|
||||
#include "tensorflow/lite/delegates/flex/delegate.h"
|
||||
#include "tensorflow/lite/delegates/flex/delegate_data.h"
|
||||
#include "tensorflow/lite/delegates/flex/util.h"
|
||||
#include "tensorflow/lite/kernels/kernel_util.h"
|
||||
|
@ -49,7 +50,6 @@ limitations under the License.
|
|||
|
||||
namespace tflite {
|
||||
namespace flex {
|
||||
namespace kernel {
|
||||
|
||||
struct OpNode;
|
||||
|
||||
|
@ -357,33 +357,29 @@ struct OpData {
|
|||
std::vector<int> subgraph_outputs;
|
||||
};
|
||||
|
||||
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
|
||||
auto* op_data = new OpData;
|
||||
DelegateKernel::DelegateKernel() : op_data_(new OpData) {}
|
||||
DelegateKernel::~DelegateKernel() {}
|
||||
|
||||
const TfLiteDelegateParams* params =
|
||||
reinterpret_cast<const TfLiteDelegateParams*>(buffer);
|
||||
CHECK(params);
|
||||
CHECK(params->delegate);
|
||||
CHECK(params->delegate->data_);
|
||||
op_data->eager_context =
|
||||
reinterpret_cast<DelegateData*>(params->delegate->data_)
|
||||
->GetEagerContext();
|
||||
op_data->buffer_map = reinterpret_cast<DelegateData*>(params->delegate->data_)
|
||||
->GetBufferMap(context);
|
||||
TfLiteStatus DelegateKernel::Init(TfLiteContext* context,
|
||||
const TfLiteDelegateParams* params) {
|
||||
auto* flex_delegate_data =
|
||||
reinterpret_cast<FlexDelegate*>(params->delegate->data_)->mutable_data();
|
||||
op_data_->eager_context = flex_delegate_data->GetEagerContext();
|
||||
op_data_->buffer_map = flex_delegate_data->GetBufferMap(context);
|
||||
|
||||
CHECK(params->output_tensors);
|
||||
std::set<int> output_set;
|
||||
for (auto tensor_index : TfLiteIntArrayView(params->output_tensors)) {
|
||||
op_data->subgraph_outputs.push_back(tensor_index);
|
||||
op_data_->subgraph_outputs.push_back(tensor_index);
|
||||
output_set.insert(tensor_index);
|
||||
}
|
||||
|
||||
CHECK(params->input_tensors);
|
||||
for (auto tensor_index : TfLiteIntArrayView(params->input_tensors)) {
|
||||
op_data->subgraph_inputs.push_back(tensor_index);
|
||||
op_data_->subgraph_inputs.push_back(tensor_index);
|
||||
}
|
||||
|
||||
op_data->nodes.reserve(params->nodes_to_replace->size);
|
||||
op_data_->nodes.reserve(params->nodes_to_replace->size);
|
||||
|
||||
CHECK(params->nodes_to_replace);
|
||||
tensorflow::Status status;
|
||||
|
@ -392,8 +388,8 @@ void* Init(TfLiteContext* context, const char* buffer, size_t length) {
|
|||
TfLiteRegistration* reg;
|
||||
context->GetNodeAndRegistration(context, node_index, &node, &reg);
|
||||
|
||||
op_data->nodes.emplace_back(new OpNode(node->inputs, node->outputs));
|
||||
OpNode& node_data = *op_data->nodes.back();
|
||||
op_data_->nodes.emplace_back(new OpNode(node->inputs, node->outputs));
|
||||
OpNode& node_data = *op_data_->nodes.back();
|
||||
|
||||
node_data.set_index(node_index);
|
||||
node_data.set_name("");
|
||||
|
@ -401,16 +397,11 @@ void* Init(TfLiteContext* context, const char* buffer, size_t length) {
|
|||
status = node_data.InitializeNodeDef(node->custom_initial_data,
|
||||
node->custom_initial_data_size);
|
||||
if (!status.ok()) break;
|
||||
status = node_data.BuildEagerOp(op_data->eager_context);
|
||||
status = node_data.BuildEagerOp(op_data_->eager_context);
|
||||
if (!status.ok()) break;
|
||||
}
|
||||
|
||||
if (ConvertStatus(context, status) != kTfLiteOk) {
|
||||
// We can't return an error from this function but ConvertStatus will
|
||||
// report them and we will stop processing in Prepare() if anything went
|
||||
// wrong.
|
||||
return op_data;
|
||||
}
|
||||
TF_LITE_ENSURE_STATUS(ConvertStatus(context, status));
|
||||
|
||||
// Given a TfLite tensor index, return the OpNode that produces it,
|
||||
// along with its index into that OpNode's list of outputs.
|
||||
|
@ -418,7 +409,7 @@ void* Init(TfLiteContext* context, const char* buffer, size_t length) {
|
|||
|
||||
// Find out how each tensor is produced. This does not account for
|
||||
// tensors that are not produced by eager ops.
|
||||
for (auto& node_data : op_data->nodes) {
|
||||
for (auto& node_data : op_data_->nodes) {
|
||||
node_data->mutable_outputs()->InitializeGraphOutputs(output_set);
|
||||
for (int i = 0; i < node_data->outputs().Size(); ++i) {
|
||||
int output_index = node_data->outputs().TfLiteIndex(i);
|
||||
|
@ -428,21 +419,15 @@ void* Init(TfLiteContext* context, const char* buffer, size_t length) {
|
|||
|
||||
// For each node, resolve the inputs, so we can keep pointers to the nodes
|
||||
// that produce them.
|
||||
for (auto& node_data : op_data->nodes) {
|
||||
for (auto& node_data : op_data_->nodes) {
|
||||
node_data->mutable_inputs()->InitializeTensorSources(tflite_tensor_sources);
|
||||
}
|
||||
|
||||
return op_data;
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
void Free(TfLiteContext* context, void* buffer) {
|
||||
delete reinterpret_cast<OpData*>(buffer);
|
||||
}
|
||||
|
||||
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
const auto* op_data = reinterpret_cast<OpData*>(node->user_data);
|
||||
TfLiteStatus DelegateKernel::Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
TF_LITE_ENSURE_MSG(
|
||||
context, op_data->eager_context != nullptr,
|
||||
context, op_data_->eager_context != nullptr,
|
||||
"Failed to initialize eager context. This often happens when a CPU "
|
||||
"device has not been registered, presumably because some symbols from "
|
||||
"tensorflow/core:core_cpu_impl were not linked into the binary.");
|
||||
|
@ -452,8 +437,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
|||
std::map<int, int> tensor_ref_count;
|
||||
|
||||
// Whenever we find a constant tensor, insert it in the buffer map.
|
||||
BufferMap* buffer_map = op_data->buffer_map;
|
||||
for (auto tensor_index : op_data->subgraph_inputs) {
|
||||
BufferMap* buffer_map = op_data_->buffer_map;
|
||||
for (auto tensor_index : op_data_->subgraph_inputs) {
|
||||
TfLiteTensor* tensor = &context->tensors[tensor_index];
|
||||
if (IsConstantTensor(tensor)) {
|
||||
if (!buffer_map->HasTensor(tensor_index)) {
|
||||
|
@ -469,12 +454,12 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
|||
|
||||
// All output tensors are allocated by TensorFlow/Eager, so we
|
||||
// mark them as kTfLiteDynamic.
|
||||
for (auto tensor_index : op_data->subgraph_outputs) {
|
||||
for (auto tensor_index : op_data_->subgraph_outputs) {
|
||||
SetTensorToDynamic(&context->tensors[tensor_index]);
|
||||
++tensor_ref_count[tensor_index];
|
||||
}
|
||||
|
||||
for (const auto& node_data : op_data->nodes) {
|
||||
for (const auto& node_data : op_data_->nodes) {
|
||||
if (node_data->nodedef().op().empty()) {
|
||||
context->ReportError(context, "Invalid NodeDef in Flex op '%s'",
|
||||
node_data->name().c_str());
|
||||
|
@ -490,7 +475,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
|||
// All tensors that are referenced exactly once are marked as "forwardable",
|
||||
// meaning that we will allow TensorFlow to reuse its buffer as the output of
|
||||
// an op.
|
||||
for (auto& node_data : op_data->nodes) {
|
||||
for (auto& node_data : op_data_->nodes) {
|
||||
for (int i = 0; i < node_data->inputs().Size(); ++i) {
|
||||
bool f = (tensor_ref_count[node_data->inputs().TfLiteIndex(i)] == 1);
|
||||
node_data->mutable_inputs()->SetForwardable(i, f);
|
||||
|
@ -500,13 +485,12 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
|||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
auto* op_data = reinterpret_cast<OpData*>(node->user_data);
|
||||
BufferMap* buffer_map = op_data->buffer_map;
|
||||
TfLiteStatus DelegateKernel::Invoke(TfLiteContext* context, TfLiteNode* node) {
|
||||
BufferMap* buffer_map = op_data_->buffer_map;
|
||||
|
||||
// Insert a tensor in the buffer map for all inputs that are not constant.
|
||||
// Constants were handled in Prepare() already.
|
||||
for (auto tensor_index : op_data->subgraph_inputs) {
|
||||
for (auto tensor_index : op_data_->subgraph_inputs) {
|
||||
TfLiteTensor* tensor = &context->tensors[tensor_index];
|
||||
if (!IsConstantTensor(tensor)) {
|
||||
// If this tensor is part of an earlier TF subgraph we should not add it
|
||||
|
@ -519,7 +503,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
|||
}
|
||||
|
||||
// Execute the TensorFlow Ops sequentially.
|
||||
for (auto& node_data : op_data->nodes) {
|
||||
for (auto& node_data : op_data_->nodes) {
|
||||
TFLITE_SCOPED_DELEGATE_OPERATOR_PROFILE(
|
||||
reinterpret_cast<Profiler*>(context->profiler),
|
||||
node_data->name().c_str(), node_data->index());
|
||||
|
@ -528,7 +512,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
|||
TF_LITE_ENSURE_OK(context, ConvertStatus(context, status));
|
||||
}
|
||||
|
||||
for (auto tensor_index : op_data->subgraph_outputs) {
|
||||
for (auto tensor_index : op_data_->subgraph_outputs) {
|
||||
if (!buffer_map->HasTensor(tensor_index)) {
|
||||
context->ReportError(context, "Cannot write to invalid tensor index %d",
|
||||
tensor_index);
|
||||
|
@ -546,21 +530,5 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
|||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
} // namespace kernel
|
||||
|
||||
TfLiteRegistration GetKernel() {
|
||||
TfLiteRegistration registration{
|
||||
&kernel::Init,
|
||||
&kernel::Free,
|
||||
&kernel::Prepare,
|
||||
&kernel::Eval,
|
||||
nullptr, // .profiling_string
|
||||
kTfLiteBuiltinDelegate, // .builtin_code
|
||||
"TfLiteFlexDelegate", // .custom_name
|
||||
1, // .version
|
||||
};
|
||||
return registration;
|
||||
}
|
||||
|
||||
} // namespace flex
|
||||
} // namespace tflite
|
||||
|
|
|
@ -15,18 +15,28 @@ limitations under the License.
|
|||
#ifndef TENSORFLOW_LITE_DELEGATES_FLEX_KERNEL_H_
|
||||
#define TENSORFLOW_LITE_DELEGATES_FLEX_KERNEL_H_
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "tensorflow/lite/c/common.h"
|
||||
#include "tensorflow/lite/delegates/utils/simple_delegate.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace flex {
|
||||
|
||||
// Return the registration object used to initialize and execute ops that will
|
||||
// be delegated to TensorFlow's Eager runtime. This TF Lite op is created by
|
||||
// the flex delegate to handle execution of a supported subgraph. The usual
|
||||
// flow is that the delegate informs the interpreter of supported nodes in a
|
||||
// graph, and each supported subgraph is replaced with one instance of this
|
||||
// kernel.
|
||||
TfLiteRegistration GetKernel();
|
||||
struct OpData;
|
||||
class DelegateKernel : public SimpleDelegateKernelInterface {
|
||||
public:
|
||||
DelegateKernel();
|
||||
~DelegateKernel() override;
|
||||
|
||||
TfLiteStatus Init(TfLiteContext* context,
|
||||
const TfLiteDelegateParams* params) override;
|
||||
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) override;
|
||||
TfLiteStatus Invoke(TfLiteContext* context, TfLiteNode* node) override;
|
||||
|
||||
private:
|
||||
std::unique_ptr<OpData> op_data_;
|
||||
};
|
||||
|
||||
} // namespace flex
|
||||
} // namespace tflite
|
||||
|
|
|
@ -12,38 +12,30 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/lite/delegates/flex/kernel.h"
|
||||
|
||||
#include <gmock/gmock.h>
|
||||
#include <gtest/gtest.h>
|
||||
#include "tensorflow/lite/delegates/flex/delegate.h"
|
||||
#include "tensorflow/lite/delegates/flex/delegate_data.h"
|
||||
#include "tensorflow/lite/delegates/flex/test_util.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace flex {
|
||||
namespace {
|
||||
namespace testing {
|
||||
|
||||
using ::testing::ContainsRegex;
|
||||
using ::testing::ElementsAre;
|
||||
using ::testing::ElementsAreArray;
|
||||
|
||||
TfLiteStatus GenericPrepare(TfLiteContext* context, TfLiteDelegate* delegate,
|
||||
const std::vector<int>& supported_nodes) {
|
||||
TfLiteIntArray* size_and_nodes =
|
||||
ConvertVectorToTfLiteIntArray(supported_nodes);
|
||||
TF_LITE_ENSURE_STATUS(context->ReplaceNodeSubsetsWithDelegateKernels(
|
||||
context, flex::GetKernel(), size_and_nodes, delegate));
|
||||
TfLiteIntArrayFree(size_and_nodes);
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
// There is no easy way to pass a parameter into the TfLiteDelegate's
|
||||
// 'prepare' function, so we keep a global map for testing purposes.
|
||||
// To avoid collisions use: GetPrepareFunction<__LINE__>().
|
||||
std::map<int, std::vector<int>>* GetGlobalOpLists() {
|
||||
static auto* op_list = new std::map<int, std::vector<int>>;
|
||||
return op_list;
|
||||
}
|
||||
// A testing flex delegate that supports every node regardless of whether it's
|
||||
// actually supported or not. It's only for testing certain scenarios.
|
||||
class TestFlexDelegate : public FlexDelegate {
|
||||
protected:
|
||||
bool IsNodeSupportedByDelegate(const TfLiteRegistration* registration,
|
||||
const TfLiteNode* node,
|
||||
TfLiteContext* context) const override {
|
||||
return true;
|
||||
}
|
||||
};
|
||||
|
||||
class KernelTest : public testing::FlexModelTest {
|
||||
public:
|
||||
|
@ -51,51 +43,16 @@ class KernelTest : public testing::FlexModelTest {
|
|||
static constexpr int kTwos = 2; // This is the index of a tensor of 2's.
|
||||
static constexpr int kMaxTensors = 30;
|
||||
|
||||
static void SetUpTestSuite() { GetGlobalOpLists()->clear(); }
|
||||
KernelTest() { interpreter_.reset(new Interpreter(&error_reporter_)); }
|
||||
|
||||
KernelTest() {
|
||||
CHECK(delegate_data_.Prepare(tensorflow::SessionOptions{}).ok());
|
||||
interpreter_.reset(new Interpreter(&error_reporter_));
|
||||
void ApplyFlexDelegate(std::unique_ptr<FlexDelegate> delegate = nullptr) {
|
||||
auto flex_delegate = FlexDelegate::Create(std::move(delegate));
|
||||
auto* delegate_data =
|
||||
reinterpret_cast<FlexDelegate*>(flex_delegate->data_)->mutable_data();
|
||||
CHECK(delegate_data->Prepare(tensorflow::SessionOptions{}).ok());
|
||||
CHECK(interpreter_->ModifyGraphWithDelegate(std::move(flex_delegate)) ==
|
||||
kTfLiteOk);
|
||||
}
|
||||
|
||||
typedef TfLiteStatus (*PrepareFunction)(TfLiteContext* context,
|
||||
TfLiteDelegate* delegate);
|
||||
|
||||
template <int KEY>
|
||||
PrepareFunction GetPrepareFunction() {
|
||||
GetGlobalOpLists()->insert({KEY, tf_ops_});
|
||||
return [](TfLiteContext* context, TfLiteDelegate* delegate) {
|
||||
return GenericPrepare(context, delegate, GetGlobalOpLists()->at(KEY));
|
||||
};
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void ConfigureDelegate(T prepare_function) {
|
||||
delegate_.data_ = &delegate_data_;
|
||||
delegate_.flags = kTfLiteDelegateFlagsAllowDynamicTensors;
|
||||
delegate_.FreeBufferHandle = nullptr;
|
||||
delegate_.Prepare = prepare_function;
|
||||
delegate_.CopyFromBufferHandle = [](TfLiteContext* context,
|
||||
TfLiteDelegate* delegate,
|
||||
TfLiteBufferHandle buffer_handle,
|
||||
TfLiteTensor* output) {
|
||||
auto* delegate_data = reinterpret_cast<DelegateData*>(delegate->data_);
|
||||
auto* buffer_map = delegate_data->GetBufferMap(context);
|
||||
if (!buffer_map->HasTensor(buffer_handle)) {
|
||||
context->ReportError(context, "Tensor '%d' not found", buffer_handle);
|
||||
return kTfLiteError;
|
||||
}
|
||||
tensorflow::StringPiece values =
|
||||
buffer_map->GetTensor(buffer_handle).tensor_data();
|
||||
memcpy(output->data.raw, values.data(), values.size());
|
||||
return kTfLiteOk;
|
||||
};
|
||||
CHECK(interpreter_->ModifyGraphWithDelegate(&delegate_) == kTfLiteOk);
|
||||
}
|
||||
|
||||
private:
|
||||
DelegateData delegate_data_;
|
||||
TfLiteDelegate delegate_;
|
||||
};
|
||||
|
||||
TEST_F(KernelTest, FullGraph) {
|
||||
|
@ -108,10 +65,7 @@ TEST_F(KernelTest, FullGraph) {
|
|||
AddTfOp(testing::kAdd, {2, 5}, {7});
|
||||
AddTfOp(testing::kMul, {6, 7}, {8});
|
||||
|
||||
// Apply Delegate.
|
||||
ConfigureDelegate([](TfLiteContext* context, TfLiteDelegate* delegate) {
|
||||
return GenericPrepare(context, delegate, {0, 1, 2, 3, 4});
|
||||
});
|
||||
ApplyFlexDelegate();
|
||||
|
||||
// Define inputs.
|
||||
SetShape(0, {2, 2, 1});
|
||||
|
@ -140,9 +94,7 @@ TEST_F(KernelTest, BadTensorFlowOp) {
|
|||
AddTensors(2, {0}, {1}, kTfLiteFloat32, {3});
|
||||
AddTfOp(testing::kNonExistent, {0}, {1});
|
||||
|
||||
ConfigureDelegate([](TfLiteContext* context, TfLiteDelegate* delegate) {
|
||||
return GenericPrepare(context, delegate, {0});
|
||||
});
|
||||
ApplyFlexDelegate(std::unique_ptr<FlexDelegate>(new TestFlexDelegate()));
|
||||
|
||||
ASSERT_NE(interpreter_->AllocateTensors(), kTfLiteOk);
|
||||
ASSERT_THAT(error_reporter().error_messages(),
|
||||
|
@ -153,9 +105,7 @@ TEST_F(KernelTest, BadNumberOfOutputs) {
|
|||
AddTensors(3, {0}, {1, 2}, kTfLiteFloat32, {3});
|
||||
AddTfOp(testing::kIdentity, {0}, {1, 2});
|
||||
|
||||
ConfigureDelegate([](TfLiteContext* context, TfLiteDelegate* delegate) {
|
||||
return GenericPrepare(context, delegate, {0});
|
||||
});
|
||||
ApplyFlexDelegate();
|
||||
|
||||
SetShape(0, {2, 2, 1});
|
||||
SetValues(0, {1.1f, 2.2f, 3.3f, 4.4f});
|
||||
|
@ -171,9 +121,7 @@ TEST_F(KernelTest, IncompatibleNodeDef) {
|
|||
// Cast is a TF op, but we don't add the proper nodedef to it in AddTfOp.
|
||||
AddTfOp(testing::kIncompatibleNodeDef, {0}, {1});
|
||||
|
||||
ConfigureDelegate([](TfLiteContext* context, TfLiteDelegate* delegate) {
|
||||
return GenericPrepare(context, delegate, {0});
|
||||
});
|
||||
ApplyFlexDelegate();
|
||||
|
||||
SetShape(0, {2, 2, 1});
|
||||
SetValues(0, {1.1f, 2.2f, 3.3f, 4.4f});
|
||||
|
@ -188,14 +136,14 @@ TEST_F(KernelTest, WrongSetOfNodes) {
|
|||
AddTfOp(testing::kUnpack, {0}, {1, 2});
|
||||
AddTfLiteMulOp({1, 2}, {3});
|
||||
|
||||
// Specify that testing::kMul (#1) is supported when it actually isn't.
|
||||
ConfigureDelegate([](TfLiteContext* context, TfLiteDelegate* delegate) {
|
||||
return GenericPrepare(context, delegate, {0, 1});
|
||||
});
|
||||
// Specify that testing::kMul (#1) is supported when it actually isn't so that
|
||||
// we choose to use the TestFlexDelegate that supports every node regardless
|
||||
// of whether it's actually supported or not.
|
||||
ApplyFlexDelegate(std::unique_ptr<FlexDelegate>(new TestFlexDelegate()));
|
||||
|
||||
ASSERT_NE(interpreter_->AllocateTensors(), kTfLiteOk);
|
||||
ASSERT_THAT(error_reporter().error_messages(),
|
||||
ContainsRegex("Invalid NodeDef in Flex op"));
|
||||
ContainsRegex("Cannot convert empty data into a valid NodeDef"));
|
||||
}
|
||||
|
||||
TEST_F(KernelTest, MixedGraph) {
|
||||
|
@ -207,9 +155,7 @@ TEST_F(KernelTest, MixedGraph) {
|
|||
AddTfOp(testing::kAdd, {2, 5}, {7});
|
||||
AddTfLiteMulOp({6, 7}, {8});
|
||||
|
||||
ConfigureDelegate([](TfLiteContext* context, TfLiteDelegate* delegate) {
|
||||
return GenericPrepare(context, delegate, {0, 1, 2, 3});
|
||||
});
|
||||
ApplyFlexDelegate();
|
||||
|
||||
SetShape(0, {2, 2, 1});
|
||||
SetValues(0, {1.1f, 2.2f, 3.3f, 4.4f});
|
||||
|
@ -251,14 +197,7 @@ TEST_F(KernelTest, SplitGraph) {
|
|||
// The two branches added together:
|
||||
AddTfOp(testing::kAdd, {9, 16}, {17}); // => 16
|
||||
|
||||
ConfigureDelegate([](TfLiteContext* context, TfLiteDelegate* delegate) {
|
||||
// All ops but #3 are TF ops, handled by the delegate. However, because #4
|
||||
// depends on the non-TF op, two subgraphs are necessary:
|
||||
// TF subgraph 1: 0, 1, 2, 6, 7, 8, 9
|
||||
// TF Lite Op: 3
|
||||
// TF subgraph 2: 4, 5, 10
|
||||
return GenericPrepare(context, delegate, {0, 1, 2, 4, 5, 6, 7, 8, 9, 10});
|
||||
});
|
||||
ApplyFlexDelegate();
|
||||
|
||||
SetShape(0, {2, 2, 2, 1});
|
||||
SetValues(0, a);
|
||||
|
@ -291,9 +230,8 @@ class MultipleSubgraphsTest : public KernelTest {
|
|||
public:
|
||||
static constexpr int kInput = 0;
|
||||
|
||||
void PrepareInterpreter(PrepareFunction prepare,
|
||||
const std::vector<float>& input) {
|
||||
ConfigureDelegate(prepare);
|
||||
void PrepareInterpreter(const std::vector<float>& input) {
|
||||
ApplyFlexDelegate();
|
||||
|
||||
SetShape(kOnes, {3});
|
||||
SetValues(kOnes, {1.0f, 1.0f, 1.0f});
|
||||
|
@ -336,7 +274,7 @@ TEST_F(MultipleSubgraphsTest, ForwardabilityIsLocal) {
|
|||
AddTfLiteMulOp({10, 7}, {12});
|
||||
|
||||
auto input = {3.0f, 4.0f, 5.0f};
|
||||
PrepareInterpreter(GetPrepareFunction<__LINE__>(), input);
|
||||
PrepareInterpreter(input);
|
||||
|
||||
ASSERT_TRUE(Invoke());
|
||||
ASSERT_THAT(GetValues(12), ElementsAreArray(Apply(input, [](float in) {
|
||||
|
@ -371,7 +309,7 @@ TEST_F(MultipleSubgraphsTest, DoNotRemoveInputTensors) {
|
|||
AddTfLiteMulOp({10, 7}, {12});
|
||||
|
||||
auto input = {3.0f, 4.0f, 5.0f};
|
||||
PrepareInterpreter(GetPrepareFunction<__LINE__>(), input);
|
||||
PrepareInterpreter(input);
|
||||
|
||||
ASSERT_TRUE(Invoke());
|
||||
ASSERT_THAT(GetValues(12), ElementsAreArray(Apply(input, [](float in) {
|
||||
|
@ -405,7 +343,7 @@ TEST_F(MultipleSubgraphsTest, DoNotForwardInputTensors) {
|
|||
AddTfLiteMulOp({10, 7}, {12});
|
||||
|
||||
auto input = {3.0f, 4.0f, 5.0f};
|
||||
PrepareInterpreter(GetPrepareFunction<__LINE__>(), input);
|
||||
PrepareInterpreter(input);
|
||||
|
||||
ASSERT_TRUE(Invoke());
|
||||
ASSERT_THAT(GetValues(12), ElementsAreArray(Apply(input, [](float in) {
|
||||
|
@ -413,7 +351,7 @@ TEST_F(MultipleSubgraphsTest, DoNotForwardInputTensors) {
|
|||
})));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace testing
|
||||
} // namespace flex
|
||||
} // namespace tflite
|
||||
|
||||
|
|
|
@ -32,6 +32,7 @@ TfLiteRegistration GetDelegateKernelRegistration(
|
|||
kernel_registration.profiling_string = nullptr;
|
||||
kernel_registration.builtin_code = kTfLiteBuiltinDelegate;
|
||||
kernel_registration.custom_name = delegate->name();
|
||||
kernel_registration.version = 1;
|
||||
kernel_registration.free = [](TfLiteContext* context, void* buffer) -> void {
|
||||
delete reinterpret_cast<SimpleDelegateKernelInterface*>(buffer);
|
||||
};
|
||||
|
@ -77,6 +78,7 @@ TfLiteStatus DelegatePrepare(TfLiteContext* context,
|
|||
TfLiteDelegate* base_delegate) {
|
||||
auto* delegate =
|
||||
reinterpret_cast<SimpleDelegateInterface*>(base_delegate->data_);
|
||||
TF_LITE_ENSURE_STATUS(delegate->Initialize(context));
|
||||
delegates::IsNodeSupportedFn node_supported_fn =
|
||||
[=](TfLiteContext* context, TfLiteNode* node,
|
||||
TfLiteRegistration* registration,
|
||||
|
@ -125,5 +127,4 @@ void TfLiteDelegateFactory::DeleteSimpleDelegate(TfLiteDelegate* delegate) {
|
|||
delete simple_delegate;
|
||||
delete delegate;
|
||||
}
|
||||
|
||||
} // namespace tflite
|
||||
|
|
|
@ -20,8 +20,12 @@ limitations under the License.
|
|||
// this interface to build/prepare/invoke the delegated subgraph.
|
||||
// - SimpleDelegateInterface:
|
||||
// This class wraps TFLiteDelegate and users need to implement the interface and
|
||||
// then Call GetFinalizedDelegate() to get TfLiteDelegate* that can be passed to
|
||||
// ModifyGraphWithDelegate.
|
||||
// then call TfLiteDelegateFactory::CreateSimpleDelegate(...) to get
|
||||
// TfLiteDelegate* that can be passed to ModifyGraphWithDelegate and free it via
|
||||
// TfLiteDelegateFactory::DeleteSimpleDelegate(...).
|
||||
// or call TfLiteDelegateFactory::Create(...) to get a std::unique_ptr
|
||||
// TfLiteDelegate that can also be passed to ModifyGraphWithDelegate, in which
|
||||
// case TfLite interpreter takes the memory ownership of the delegate.
|
||||
#ifndef TENSORFLOW_LITE_DELEGATES_UTILS_SIMPLE_DELEGATE_H_
|
||||
#define TENSORFLOW_LITE_DELEGATES_UTILS_SIMPLE_DELEGATE_H_
|
||||
|
||||
|
@ -31,6 +35,9 @@ limitations under the License.
|
|||
|
||||
namespace tflite {
|
||||
|
||||
using TfLiteDelegateUniquePtr =
|
||||
std::unique_ptr<TfLiteDelegate, void (*)(TfLiteDelegate*)>;
|
||||
|
||||
// Users should inherit from this class and implement the interface below.
|
||||
// Each instance represents a single part of the graph (subgraph).
|
||||
class SimpleDelegateKernelInterface {
|
||||
|
@ -49,6 +56,7 @@ class SimpleDelegateKernelInterface {
|
|||
|
||||
// Actual subgraph inference should happen on this call.
|
||||
// Returns the status, signalling any errors.
|
||||
// TODO(b/157882025): change this to Eval to be consistent w/ a TFLite kernel.
|
||||
virtual TfLiteStatus Invoke(TfLiteContext* context, TfLiteNode* node) = 0;
|
||||
};
|
||||
|
||||
|
@ -58,6 +66,7 @@ class SimpleDelegateKernelInterface {
|
|||
//
|
||||
// Clients should implement the following methods:
|
||||
// - IsNodeSupportedByDelegate
|
||||
// - Initialize
|
||||
// - name
|
||||
// - CreateDelegateKernelInterface
|
||||
class SimpleDelegateInterface {
|
||||
|
@ -71,8 +80,14 @@ class SimpleDelegateInterface {
|
|||
const TfLiteNode* node,
|
||||
TfLiteContext* context) const = 0;
|
||||
|
||||
// Initialize the delegate before finding and replacing TfLite nodes with
|
||||
// delegate kernels, for example, retrieving some TFLite settings from
|
||||
// 'context'.
|
||||
virtual TfLiteStatus Initialize(TfLiteContext* context) = 0;
|
||||
|
||||
// Returns a name that identifies the delegate.
|
||||
// This name is used for debugging/logging/profiling.
|
||||
// TODO(b/157882025): change this to Name()
|
||||
virtual const char* name() const = 0;
|
||||
|
||||
// Returns instance of an object that implements the interface
|
||||
|
@ -84,13 +99,8 @@ class SimpleDelegateInterface {
|
|||
CreateDelegateKernelInterface() = 0;
|
||||
};
|
||||
|
||||
// Factory class that provides two static methods
|
||||
// CreateSimpleDelegate
|
||||
// DeleteSimpleDelegate
|
||||
// Which should be used to construct TfLiteDelegate from
|
||||
// Simple Delegate and delete TfLiteDelegate and SimpleDelegate give
|
||||
// tfLiteDelegate* created from 'CreateSimpleDelegate' method.
|
||||
// Users should use these methods to Create and Destroy the delegate.
|
||||
// Factory class that provides static methods to deal with SimpleDelegate
|
||||
// creation and deletion.
|
||||
class TfLiteDelegateFactory {
|
||||
public:
|
||||
// Creates TfLiteDelegate from the provided SimpleDelegateInterface.
|
||||
|
@ -99,9 +109,17 @@ class TfLiteDelegateFactory {
|
|||
std::unique_ptr<SimpleDelegateInterface> simple_delegate);
|
||||
|
||||
// Deletes 'delegate'; the passed pointer must be the one returned
|
||||
// from GetFinalizedDelegate.
|
||||
// from CreateSimpleDelegate.
|
||||
// This function will destruct the SimpleDelegate object too.
|
||||
static void DeleteSimpleDelegate(TfLiteDelegate* delegate);
|
||||
|
||||
// A convenient function wrapping the above two functions and returning a
|
||||
// std::unique_ptr type for auto memory management.
|
||||
inline static TfLiteDelegateUniquePtr Create(
|
||||
std::unique_ptr<SimpleDelegateInterface> simple_delegate) {
|
||||
return TfLiteDelegateUniquePtr(
|
||||
CreateSimpleDelegate(std::move(simple_delegate)), DeleteSimpleDelegate);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace tflite
|
||||
|
|
|
@ -72,7 +72,12 @@ class TestSimpleDelegate : public SimpleDelegateInterface {
|
|||
return options_.allowed_builtin_code == registration->builtin_code;
|
||||
}
|
||||
|
||||
const char* name() const override { return "TestSimpleDelegate"; }
|
||||
TfLiteStatus Initialize(TfLiteContext* context) override { return kTfLiteOk; }
|
||||
|
||||
const char* name() const override {
|
||||
static constexpr char kName[] = "TestSimpleDelegate";
|
||||
return kName;
|
||||
}
|
||||
|
||||
std::unique_ptr<SimpleDelegateKernelInterface> CreateDelegateKernelInterface()
|
||||
override {
|
||||
|
@ -113,27 +118,24 @@ class TestDelegate : public ::testing::Test {
|
|||
reg);
|
||||
}
|
||||
|
||||
void TearDown() override {
|
||||
interpreter_.reset();
|
||||
TfLiteDelegateFactory::DeleteSimpleDelegate(delegate_);
|
||||
}
|
||||
void TearDown() override { interpreter_.reset(); }
|
||||
|
||||
protected:
|
||||
std::unique_ptr<Interpreter> interpreter_;
|
||||
TfLiteDelegate* delegate_ = nullptr;
|
||||
};
|
||||
|
||||
TEST_F(TestDelegate, BasicDelegate) {
|
||||
TestSimpleDelegateOptions options;
|
||||
options.allowed_builtin_code = kTfLiteBuiltinAdd;
|
||||
delegate_ = TfLiteDelegateFactory::CreateSimpleDelegate(
|
||||
auto delegate = TfLiteDelegateFactory::Create(
|
||||
std::make_unique<TestSimpleDelegate>(options));
|
||||
interpreter_->ModifyGraphWithDelegate(delegate_);
|
||||
interpreter_->ModifyGraphWithDelegate(std::move(delegate));
|
||||
|
||||
ASSERT_EQ(interpreter_->execution_plan().size(), 1);
|
||||
int node = interpreter_->execution_plan()[0];
|
||||
const auto* node_and_reg = interpreter_->node_and_registration(node);
|
||||
EXPECT_EQ("TestSimpleDelegate", node_and_reg->second.custom_name);
|
||||
EXPECT_STREQ("TestSimpleDelegate", node_and_reg->second.custom_name);
|
||||
EXPECT_EQ(1, node_and_reg->second.version);
|
||||
|
||||
const TfLiteDelegateParams* params = static_cast<const TfLiteDelegateParams*>(
|
||||
node_and_reg->first.builtin_data);
|
||||
|
@ -154,9 +156,9 @@ TEST_F(TestDelegate, BasicDelegate) {
|
|||
TEST_F(TestDelegate, NoNodesToDelegate) {
|
||||
TestSimpleDelegateOptions options;
|
||||
options.allowed_builtin_code = kTfLiteBuiltinSub;
|
||||
delegate_ = TfLiteDelegateFactory::CreateSimpleDelegate(
|
||||
auto delegate = TfLiteDelegateFactory::Create(
|
||||
std::make_unique<TestSimpleDelegate>(options));
|
||||
interpreter_->ModifyGraphWithDelegate(delegate_);
|
||||
interpreter_->ModifyGraphWithDelegate(std::move(delegate));
|
||||
|
||||
ASSERT_EQ(interpreter_->execution_plan().size(), 3);
|
||||
}
|
||||
|
@ -165,19 +167,20 @@ TEST_F(TestDelegate, DelegateFailedPrepare) {
|
|||
TestSimpleDelegateOptions options;
|
||||
options.allowed_builtin_code = kTfLiteBuiltinAdd;
|
||||
options.error_during_prepare = true;
|
||||
delegate_ = TfLiteDelegateFactory::CreateSimpleDelegate(
|
||||
auto delegate = TfLiteDelegateFactory::Create(
|
||||
std::make_unique<TestSimpleDelegate>(options));
|
||||
ASSERT_EQ(kTfLiteDelegateError,
|
||||
interpreter_->ModifyGraphWithDelegate(delegate_));
|
||||
interpreter_->ModifyGraphWithDelegate(std::move(delegate)));
|
||||
}
|
||||
|
||||
TEST_F(TestDelegate, DelegateFailedInvoke) {
|
||||
TestSimpleDelegateOptions options;
|
||||
options.allowed_builtin_code = kTfLiteBuiltinAdd;
|
||||
options.error_during_invoke = true;
|
||||
delegate_ = TfLiteDelegateFactory::CreateSimpleDelegate(
|
||||
auto delegate = TfLiteDelegateFactory::Create(
|
||||
std::make_unique<TestSimpleDelegate>(options));
|
||||
ASSERT_EQ(kTfLiteOk, interpreter_->ModifyGraphWithDelegate(delegate_));
|
||||
ASSERT_EQ(kTfLiteOk,
|
||||
interpreter_->ModifyGraphWithDelegate(std::move(delegate)));
|
||||
ASSERT_EQ(kTfLiteError, interpreter_->Invoke());
|
||||
}
|
||||
|
||||
|
@ -185,10 +188,10 @@ TEST_F(TestDelegate, DelegateFailedInit) {
|
|||
TestSimpleDelegateOptions options;
|
||||
options.allowed_builtin_code = kTfLiteBuiltinAdd;
|
||||
options.error_during_init = true;
|
||||
delegate_ = TfLiteDelegateFactory::CreateSimpleDelegate(
|
||||
auto delegate = TfLiteDelegateFactory::Create(
|
||||
std::make_unique<TestSimpleDelegate>(options));
|
||||
ASSERT_EQ(kTfLiteDelegateError,
|
||||
interpreter_->ModifyGraphWithDelegate(delegate_));
|
||||
interpreter_->ModifyGraphWithDelegate(std::move(delegate)));
|
||||
}
|
||||
} // namespace
|
||||
} // namespace tflite
|
||||
|
|
|
@ -337,10 +337,7 @@ TfLiteDriver::TfLiteDriver(DelegateType delegate_type, bool reference_kernel)
|
|||
break;
|
||||
case DelegateType::kFlex:
|
||||
#if !defined(__APPLE__)
|
||||
delegate_ = Interpreter::TfLiteDelegatePtr(
|
||||
FlexDelegate::Create().release(), [](TfLiteDelegate* delegate) {
|
||||
delete static_cast<tflite::FlexDelegate*>(delegate);
|
||||
});
|
||||
delegate_ = FlexDelegate::Create();
|
||||
#endif
|
||||
break;
|
||||
}
|
||||
|
|
|
@ -744,8 +744,9 @@ typedef struct TfLiteDelegate {
|
|||
struct TfLiteDelegate* delegate);
|
||||
|
||||
// Copy the data from delegate buffer handle into raw memory of the given
|
||||
// 'tensor'. This cannot be null. The delegate is allowed to allocate the raw
|
||||
// bytes as long as it follows the rules for kTfLiteDynamic tensors.
|
||||
// 'tensor'. Note that the delegate is allowed to allocate the raw bytes as
|
||||
// long as it follows the rules for kTfLiteDynamic tensors, in which case this
|
||||
// cannot be null.
|
||||
TfLiteStatus (*CopyFromBufferHandle)(TfLiteContext* context,
|
||||
struct TfLiteDelegate* delegate,
|
||||
TfLiteBufferHandle buffer_handle,
|
||||
|
|
Loading…
Reference in New Issue