1. Adapt flex delegate to the newly created SimpleDelegate APIs.

2. Change SimpleDelegate APIs accordingly as needed to adapt the flex delegate.
a. Create a `version` virtual function.
b. Add a new function acting as an initialization for TfLiteDelegate Prepare.
c. Return a `std::unique_ptr` of TfLiteDelegate from the factory method of TfLiteDelegateFactory for less error-prone memory management.

PiperOrigin-RevId: 314225150
Change-Id: I6bf7ab48f1e72f390b49dd5c6ac7ecd11d29327a
This commit is contained in:
Chao Mei 2020-06-01 16:23:43 -07:00 committed by TensorFlower Gardener
parent 24580ebdd9
commit edeae9fb69
15 changed files with 254 additions and 327 deletions

View File

@ -744,8 +744,9 @@ typedef struct TfLiteDelegate {
struct TfLiteDelegate* delegate); struct TfLiteDelegate* delegate);
// Copy the data from delegate buffer handle into raw memory of the given // Copy the data from delegate buffer handle into raw memory of the given
// 'tensor'. This cannot be null. The delegate is allowed to allocate the raw // 'tensor'. Note that the delegate is allowed to allocate the raw bytes as
// bytes as long as it follows the rules for kTfLiteDynamic tensors. // long as it follows the rules for kTfLiteDynamic tensors, in which case this
// cannot be null.
TfLiteStatus (*CopyFromBufferHandle)(TfLiteContext* context, TfLiteStatus (*CopyFromBufferHandle)(TfLiteContext* context,
struct TfLiteDelegate* delegate, struct TfLiteDelegate* delegate,
TfLiteBufferHandle buffer_handle, TfLiteBufferHandle buffer_handle,

View File

@ -61,6 +61,7 @@ cc_library(
deps = [ deps = [
":delegate_data", ":delegate_data",
":delegate_only_runtime", ":delegate_only_runtime",
"//tensorflow/lite/delegates/utils:simple_delegate",
] + select({ ] + select({
"//tensorflow:android": [ "//tensorflow:android": [
"//tensorflow/core:portable_tensorflow_lib", "//tensorflow/core:portable_tensorflow_lib",
@ -82,6 +83,8 @@ cc_library(
name = "delegate_only_runtime", name = "delegate_only_runtime",
srcs = [ srcs = [
"delegate.cc", "delegate.cc",
"kernel.cc",
"kernel.h",
], ],
hdrs = [ hdrs = [
"delegate.h", "delegate.h",
@ -90,14 +93,18 @@ cc_library(
deps = [ deps = [
":buffer_map", ":buffer_map",
":delegate_data", ":delegate_data",
":kernel",
":util", ":util",
"@flatbuffers",
"@com_google_absl//absl/strings:strings", "@com_google_absl//absl/strings:strings",
"//tensorflow/lite/core/api",
"//tensorflow/lite/c:common", "//tensorflow/lite/c:common",
"//tensorflow/lite:kernel_api", "//tensorflow/lite:kernel_api",
"//tensorflow/lite:minimal_logging", "//tensorflow/lite:minimal_logging",
"//tensorflow/lite:string",
"//tensorflow/lite:string_util", "//tensorflow/lite:string_util",
"//tensorflow/lite:util", "//tensorflow/lite:util",
"//tensorflow/lite/delegates/utils:simple_delegate",
"//tensorflow/lite/kernels:kernel_util",
] + select({ ] + select({
"//tensorflow:android": [ "//tensorflow:android": [
"//tensorflow/core:portable_tensorflow_lib_lite", "//tensorflow/core:portable_tensorflow_lib_lite",
@ -106,7 +113,12 @@ cc_library(
"//tensorflow/core:portable_tensorflow_lib_lite", "//tensorflow/core:portable_tensorflow_lib_lite",
], ],
"//conditions:default": [ "//conditions:default": [
"//tensorflow/core/common_runtime/eager:context",
"//tensorflow/core/common_runtime/eager:execute",
"//tensorflow/core/common_runtime/eager:tensor_handle",
"//tensorflow/core:lib", "//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:framework",
], ],
}), }),
alwayslink = 1, alwayslink = 1,
@ -163,40 +175,6 @@ tf_cc_test(
], ],
) )
cc_library(
name = "kernel",
srcs = ["kernel.cc"],
hdrs = ["kernel.h"],
deps = [
":delegate_data",
":util",
"@flatbuffers",
"//tensorflow/lite/core/api",
"//tensorflow/lite/c:common",
"//tensorflow/lite:kernel_api",
"//tensorflow/lite:string",
"//tensorflow/lite/kernels:kernel_util",
] + select({
# TODO(b/111881878): The android_tensorflow_lib target pulls in the full
# set of core TensorFlow kernels. We may want to revisit this dependency
# to allow selective registration via build targets.
"//tensorflow:android": [
"//tensorflow/core:portable_tensorflow_lib_lite",
],
"//tensorflow:ios": [
"//tensorflow/core:portable_tensorflow_lib_lite",
],
"//conditions:default": [
"//tensorflow/core/common_runtime/eager:context",
"//tensorflow/core/common_runtime/eager:execute",
"//tensorflow/core/common_runtime/eager:tensor_handle",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:framework",
],
}),
)
tf_cc_test( tf_cc_test(
name = "kernel_test", name = "kernel_test",
size = "small", size = "small",
@ -204,20 +182,10 @@ tf_cc_test(
tags = ["no_gpu"], # GPU + flex is not officially supported. tags = ["no_gpu"], # GPU + flex is not officially supported.
deps = [ deps = [
":delegate_data", ":delegate_data",
":kernel", ":delegate_only_runtime",
":test_util", ":test_util",
"@com_google_googletest//:gtest", "@com_google_googletest//:gtest",
] + select({ ],
"//tensorflow:android": [
"//tensorflow/core:portable_tensorflow_lib",
],
"//tensorflow:ios": [
"//tensorflow/core:portable_tensorflow_lib",
],
"//conditions:default": [
"//tensorflow/core:tensorflow",
],
}),
) )
cc_library( cc_library(

View File

@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/ ==============================================================================*/
#include "tensorflow/lite/delegates/flex/delegate.h" #include "tensorflow/lite/delegates/flex/delegate.h"
#include <memory>
#include <vector> #include <vector>
#include "absl/strings/str_cat.h" #include "absl/strings/str_cat.h"
@ -27,10 +28,32 @@ limitations under the License.
#include "tensorflow/lite/util.h" #include "tensorflow/lite/util.h"
namespace tflite { namespace tflite {
namespace flex {
namespace delegate {
TfLiteStatus Prepare(TfLiteContext* context, TfLiteDelegate* delegate) { // Corresponding weak declaration found in lite/interpreter_builder.cc.
TfLiteDelegateUniquePtr AcquireFlexDelegate() {
return tflite::FlexDelegate::Create();
}
TfLiteDelegateUniquePtr FlexDelegate::Create(
std::unique_ptr<FlexDelegate> base_delegate) {
TFLITE_LOG_PROD_ONCE(TFLITE_LOG_INFO,
"Created TensorFlow Lite delegate for select TF ops.");
if (base_delegate == nullptr) {
base_delegate.reset(new FlexDelegate());
}
auto flex_delegate = TfLiteDelegateFactory::Create(std::move(base_delegate));
flex_delegate->CopyFromBufferHandle =
[](TfLiteContext* context, TfLiteDelegate* delegate,
TfLiteBufferHandle buffer_handle,
TfLiteTensor* tensor) -> TfLiteStatus {
return reinterpret_cast<FlexDelegate*>(delegate->data_)
->CopyFromBufferHandle(context, buffer_handle, tensor);
};
flex_delegate->flags |= kTfLiteDelegateFlagsAllowDynamicTensors;
return flex_delegate;
}
TfLiteStatus FlexDelegate::Initialize(TfLiteContext* context) {
// If the TensorFlow Lite thread count is explicitly configured, use it, // If the TensorFlow Lite thread count is explicitly configured, use it,
// otherwise rely on the default TensorFlow threading behavior. // otherwise rely on the default TensorFlow threading behavior.
tensorflow::SessionOptions session_options; tensorflow::SessionOptions session_options;
@ -39,47 +62,37 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteDelegate* delegate) {
context->recommended_num_threads); context->recommended_num_threads);
} }
auto status = reinterpret_cast<DelegateData*>(delegate->data_) auto status = delegate_data_.Prepare(session_options);
->Prepare(session_options);
if (!status.ok()) { if (!status.ok()) {
context->ReportError(context, "Failed to initialize TensorFlow context: %s", context->ReportError(context, "Failed to initialize TensorFlow context: %s",
status.error_message().c_str()); status.error_message().c_str());
return kTfLiteError; return kTfLiteError;
} }
// Get the nodes in the current execution plan. Interpreter owns this array.
TfLiteIntArray* plan;
TF_LITE_ENSURE_STATUS(context->GetExecutionPlan(context, &plan));
// Add all custom ops starting with "Flex" to list of supported nodes.
std::vector<int> supported_nodes;
for (int node_index : TfLiteIntArrayView(plan)) {
TfLiteNode* node;
TfLiteRegistration* registration;
TF_LITE_ENSURE_STATUS(context->GetNodeAndRegistration(
context, node_index, &node, &registration));
if (IsFlexOp(registration->custom_name)) {
supported_nodes.push_back(node_index);
}
}
// Request TFLite to partition the graph and make kernels for each independent
// node sub set.
TfLiteIntArray* size_and_nodes =
ConvertVectorToTfLiteIntArray(supported_nodes);
context->ReplaceNodeSubsetsWithDelegateKernels(context, GetKernel(),
size_and_nodes, delegate);
TfLiteIntArrayFree(size_and_nodes);
return kTfLiteOk; return kTfLiteOk;
} }
TfLiteStatus CopyFromBufferHandle(TfLiteContext* context, const char* FlexDelegate::name() const {
TfLiteDelegate* delegate, static constexpr char kName[] = "TfLiteFlexDelegate";
TfLiteBufferHandle buffer_handle, return kName;
TfLiteTensor* output) { }
BufferMap* buffer_map =
reinterpret_cast<DelegateData*>(delegate->data_)->GetBufferMap(context); bool FlexDelegate::IsNodeSupportedByDelegate(
const TfLiteRegistration* registration, const TfLiteNode* node,
TfLiteContext* context) const {
return IsFlexOp(registration->custom_name);
}
std::unique_ptr<SimpleDelegateKernelInterface>
FlexDelegate::CreateDelegateKernelInterface() {
return std::unique_ptr<SimpleDelegateKernelInterface>(
new tflite::flex::DelegateKernel());
}
TfLiteStatus FlexDelegate::CopyFromBufferHandle(
TfLiteContext* context, TfLiteBufferHandle buffer_handle,
TfLiteTensor* output) {
flex::BufferMap* buffer_map = delegate_data_.GetBufferMap(context);
if (!buffer_map->HasTensor(buffer_handle)) { if (!buffer_map->HasTensor(buffer_handle)) {
context->ReportError(context, "Invalid tensor index %d.", buffer_handle); context->ReportError(context, "Invalid tensor index %d.", buffer_handle);
@ -122,31 +135,4 @@ TfLiteStatus CopyFromBufferHandle(TfLiteContext* context,
return kTfLiteOk; return kTfLiteOk;
} }
} // namespace delegate
} // namespace flex
// Corresponding weak declaration found in lite/model.cc.
std::unique_ptr<TfLiteDelegate, void (*)(TfLiteDelegate*)>
AcquireFlexDelegate() {
return std::unique_ptr<TfLiteDelegate, void (*)(TfLiteDelegate*)>(
tflite::FlexDelegate::Create().release(), [](TfLiteDelegate* delegate) {
delete reinterpret_cast<tflite::FlexDelegate*>(delegate);
});
}
std::unique_ptr<FlexDelegate> FlexDelegate::Create() {
TFLITE_LOG_PROD_ONCE(TFLITE_LOG_INFO,
"Created TensorFlow Lite delegate for select TF ops.");
return std::unique_ptr<FlexDelegate>(new FlexDelegate());
}
FlexDelegate::FlexDelegate() : TfLiteDelegate(TfLiteDelegateCreate()) {
data_ = &delegate_data_;
Prepare = &flex::delegate::Prepare;
CopyFromBufferHandle = &flex::delegate::CopyFromBufferHandle;
flags = kTfLiteDelegateFlagsAllowDynamicTensors;
}
FlexDelegate::~FlexDelegate() {}
} // namespace tflite } // namespace tflite

View File

@ -17,9 +17,16 @@ limitations under the License.
#include "tensorflow/lite/c/common.h" #include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/delegates/flex/delegate_data.h" #include "tensorflow/lite/delegates/flex/delegate_data.h"
#include "tensorflow/lite/delegates/utils/simple_delegate.h"
namespace tflite { namespace tflite {
namespace flex {
namespace testing {
class KernelTest;
} // namespace testing
} // namespace flex
// WARNING: This is an experimental interface that is subject to change. // WARNING: This is an experimental interface that is subject to change.
// Delegate that can be used to extract parts of a graph that are designed to be // Delegate that can be used to extract parts of a graph that are designed to be
// executed by TensorFlow's runtime via Eager. // executed by TensorFlow's runtime via Eager.
@ -33,22 +40,49 @@ namespace tflite {
// ... build interpreter ... // ... build interpreter ...
// //
// if (delegate) { // if (delegate) {
// interpreter->ModifyGraphWithDelegate(delegate.get()); // interpreter->ModifyGraphWithDelegate(std::move(delegate));
// } // }
// ... run inference ... // ... run inference ...
// ... destroy interpreter ... // ... destroy interpreter ...
// ... destroy delegate ... class FlexDelegate : public SimpleDelegateInterface {
class FlexDelegate : public TfLiteDelegate {
public: public:
friend class flex::testing::KernelTest;
// Creates a delegate that supports TF ops. // Creates a delegate that supports TF ops.
// static TfLiteDelegateUniquePtr Create() {
// If the underyling TF Flex context creation fails, returns null. return Create(/*base_delegate*/ nullptr);
static std::unique_ptr<FlexDelegate> Create(); }
~FlexDelegate(); ~FlexDelegate() override {}
private: flex::DelegateData* mutable_data() { return &delegate_data_; }
FlexDelegate();
protected:
// We sometimes have to create certain stub data to test FlexDelegate. To
// achieve this, we will make a testing flex delegate class that inherits from
// FlexDelegate to override certain things for stub data creation. Therefore,
// this function accepts a FlexDelegate instance to initiliaze it properly for
// create a testing flex delegate in some cases, and it is only used in
// testing.
static TfLiteDelegateUniquePtr Create(
std::unique_ptr<FlexDelegate> base_delegate);
FlexDelegate() {}
const char* name() const override;
bool IsNodeSupportedByDelegate(const TfLiteRegistration* registration,
const TfLiteNode* node,
TfLiteContext* context) const override;
TfLiteStatus Initialize(TfLiteContext* context) override;
std::unique_ptr<SimpleDelegateKernelInterface> CreateDelegateKernelInterface()
override;
TfLiteStatus CopyFromBufferHandle(TfLiteContext* context,
TfLiteBufferHandle buffer_handle,
TfLiteTensor* output);
flex::DelegateData delegate_data_; flex::DelegateData delegate_data_;
}; };

View File

@ -26,8 +26,7 @@ using ::testing::ElementsAre;
class DelegateTest : public testing::FlexModelTest { class DelegateTest : public testing::FlexModelTest {
public: public:
DelegateTest() { DelegateTest() : delegate_(FlexDelegate::Create()) {
delegate_ = FlexDelegate::Create();
interpreter_.reset(new Interpreter(&error_reporter_)); interpreter_.reset(new Interpreter(&error_reporter_));
} }
@ -44,7 +43,7 @@ class DelegateTest : public testing::FlexModelTest {
} }
private: private:
std::unique_ptr<FlexDelegate> delegate_; std::unique_ptr<TfLiteDelegate, void (*)(TfLiteDelegate*)> delegate_;
}; };
TEST_F(DelegateTest, FullGraph) { TEST_F(DelegateTest, FullGraph) {

View File

@ -18,6 +18,7 @@ cc_library(
], ],
deps = [ deps = [
"//tensorflow/lite/delegates/flex:delegate", "//tensorflow/lite/delegates/flex:delegate",
"//tensorflow/lite/delegates/utils:simple_delegate",
"//tensorflow/lite/java/jni", "//tensorflow/lite/java/jni",
"//tensorflow/lite/testing:init_tensorflow", "//tensorflow/lite/testing:init_tensorflow",
], ],

View File

@ -16,6 +16,7 @@ limitations under the License.
#include <jni.h> #include <jni.h>
#include "tensorflow/lite/delegates/flex/delegate.h" #include "tensorflow/lite/delegates/flex/delegate.h"
#include "tensorflow/lite/delegates/utils/simple_delegate.h"
#include "tensorflow/lite/testing/init_tensorflow.h" #include "tensorflow/lite/testing/init_tensorflow.h"
#ifdef __cplusplus #ifdef __cplusplus
@ -37,7 +38,8 @@ Java_org_tensorflow_lite_flex_FlexDelegate_nativeCreateDelegate(JNIEnv* env,
JNIEXPORT void JNICALL JNIEXPORT void JNICALL
Java_org_tensorflow_lite_flex_FlexDelegate_nativeDeleteDelegate( Java_org_tensorflow_lite_flex_FlexDelegate_nativeDeleteDelegate(
JNIEnv* env, jclass clazz, jlong delegate) { JNIEnv* env, jclass clazz, jlong delegate) {
delete reinterpret_cast<tflite::FlexDelegate*>(delegate); tflite::TfLiteDelegateFactory::DeleteSimpleDelegate(
reinterpret_cast<struct TfLiteDelegate*>(delegate));
} }
#ifdef __cplusplus #ifdef __cplusplus

View File

@ -25,6 +25,7 @@ limitations under the License.
#include "tensorflow/lite/c/common.h" #include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/context_util.h" #include "tensorflow/lite/context_util.h"
#include "tensorflow/lite/core/api/profiler.h" #include "tensorflow/lite/core/api/profiler.h"
#include "tensorflow/lite/delegates/flex/delegate.h"
#include "tensorflow/lite/delegates/flex/delegate_data.h" #include "tensorflow/lite/delegates/flex/delegate_data.h"
#include "tensorflow/lite/delegates/flex/util.h" #include "tensorflow/lite/delegates/flex/util.h"
#include "tensorflow/lite/kernels/kernel_util.h" #include "tensorflow/lite/kernels/kernel_util.h"
@ -49,7 +50,6 @@ limitations under the License.
namespace tflite { namespace tflite {
namespace flex { namespace flex {
namespace kernel {
struct OpNode; struct OpNode;
@ -357,33 +357,29 @@ struct OpData {
std::vector<int> subgraph_outputs; std::vector<int> subgraph_outputs;
}; };
void* Init(TfLiteContext* context, const char* buffer, size_t length) { DelegateKernel::DelegateKernel() : op_data_(new OpData) {}
auto* op_data = new OpData; DelegateKernel::~DelegateKernel() {}
const TfLiteDelegateParams* params = TfLiteStatus DelegateKernel::Init(TfLiteContext* context,
reinterpret_cast<const TfLiteDelegateParams*>(buffer); const TfLiteDelegateParams* params) {
CHECK(params); auto* flex_delegate_data =
CHECK(params->delegate); reinterpret_cast<FlexDelegate*>(params->delegate->data_)->mutable_data();
CHECK(params->delegate->data_); op_data_->eager_context = flex_delegate_data->GetEagerContext();
op_data->eager_context = op_data_->buffer_map = flex_delegate_data->GetBufferMap(context);
reinterpret_cast<DelegateData*>(params->delegate->data_)
->GetEagerContext();
op_data->buffer_map = reinterpret_cast<DelegateData*>(params->delegate->data_)
->GetBufferMap(context);
CHECK(params->output_tensors); CHECK(params->output_tensors);
std::set<int> output_set; std::set<int> output_set;
for (auto tensor_index : TfLiteIntArrayView(params->output_tensors)) { for (auto tensor_index : TfLiteIntArrayView(params->output_tensors)) {
op_data->subgraph_outputs.push_back(tensor_index); op_data_->subgraph_outputs.push_back(tensor_index);
output_set.insert(tensor_index); output_set.insert(tensor_index);
} }
CHECK(params->input_tensors); CHECK(params->input_tensors);
for (auto tensor_index : TfLiteIntArrayView(params->input_tensors)) { for (auto tensor_index : TfLiteIntArrayView(params->input_tensors)) {
op_data->subgraph_inputs.push_back(tensor_index); op_data_->subgraph_inputs.push_back(tensor_index);
} }
op_data->nodes.reserve(params->nodes_to_replace->size); op_data_->nodes.reserve(params->nodes_to_replace->size);
CHECK(params->nodes_to_replace); CHECK(params->nodes_to_replace);
tensorflow::Status status; tensorflow::Status status;
@ -392,8 +388,8 @@ void* Init(TfLiteContext* context, const char* buffer, size_t length) {
TfLiteRegistration* reg; TfLiteRegistration* reg;
context->GetNodeAndRegistration(context, node_index, &node, &reg); context->GetNodeAndRegistration(context, node_index, &node, &reg);
op_data->nodes.emplace_back(new OpNode(node->inputs, node->outputs)); op_data_->nodes.emplace_back(new OpNode(node->inputs, node->outputs));
OpNode& node_data = *op_data->nodes.back(); OpNode& node_data = *op_data_->nodes.back();
node_data.set_index(node_index); node_data.set_index(node_index);
node_data.set_name(""); node_data.set_name("");
@ -401,16 +397,11 @@ void* Init(TfLiteContext* context, const char* buffer, size_t length) {
status = node_data.InitializeNodeDef(node->custom_initial_data, status = node_data.InitializeNodeDef(node->custom_initial_data,
node->custom_initial_data_size); node->custom_initial_data_size);
if (!status.ok()) break; if (!status.ok()) break;
status = node_data.BuildEagerOp(op_data->eager_context); status = node_data.BuildEagerOp(op_data_->eager_context);
if (!status.ok()) break; if (!status.ok()) break;
} }
if (ConvertStatus(context, status) != kTfLiteOk) { TF_LITE_ENSURE_STATUS(ConvertStatus(context, status));
// We can't return an error from this function but ConvertStatus will
// report them and we will stop processing in Prepare() if anything went
// wrong.
return op_data;
}
// Given a TfLite tensor index, return the OpNode that produces it, // Given a TfLite tensor index, return the OpNode that produces it,
// along with it index into that OpNodes list of outputs. // along with it index into that OpNodes list of outputs.
@ -418,7 +409,7 @@ void* Init(TfLiteContext* context, const char* buffer, size_t length) {
// Find out how each tensor is produced. This does not account for // Find out how each tensor is produced. This does not account for
// tensors that are not produce by eager ops. // tensors that are not produce by eager ops.
for (auto& node_data : op_data->nodes) { for (auto& node_data : op_data_->nodes) {
node_data->mutable_outputs()->InitializeGraphOutputs(output_set); node_data->mutable_outputs()->InitializeGraphOutputs(output_set);
for (int i = 0; i < node_data->outputs().Size(); ++i) { for (int i = 0; i < node_data->outputs().Size(); ++i) {
int output_index = node_data->outputs().TfLiteIndex(i); int output_index = node_data->outputs().TfLiteIndex(i);
@ -428,21 +419,15 @@ void* Init(TfLiteContext* context, const char* buffer, size_t length) {
// For each node, resolve the inputs, so we can keep pointers to the nodes // For each node, resolve the inputs, so we can keep pointers to the nodes
// that produces them. // that produces them.
for (auto& node_data : op_data->nodes) { for (auto& node_data : op_data_->nodes) {
node_data->mutable_inputs()->InitializeTensorSources(tflite_tensor_sources); node_data->mutable_inputs()->InitializeTensorSources(tflite_tensor_sources);
} }
return kTfLiteOk;
return op_data;
} }
void Free(TfLiteContext* context, void* buffer) { TfLiteStatus DelegateKernel::Prepare(TfLiteContext* context, TfLiteNode* node) {
delete reinterpret_cast<OpData*>(buffer);
}
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
const auto* op_data = reinterpret_cast<OpData*>(node->user_data);
TF_LITE_ENSURE_MSG( TF_LITE_ENSURE_MSG(
context, op_data->eager_context != nullptr, context, op_data_->eager_context != nullptr,
"Failed to initialize eager context. This often happens when a CPU " "Failed to initialize eager context. This often happens when a CPU "
"device has not been registered, presumably because some symbols from " "device has not been registered, presumably because some symbols from "
"tensorflow/core:core_cpu_impl were not linked into the binary."); "tensorflow/core:core_cpu_impl were not linked into the binary.");
@ -452,8 +437,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
std::map<int, int> tensor_ref_count; std::map<int, int> tensor_ref_count;
// Whenever we find a constant tensor, insert it in the buffer map. // Whenever we find a constant tensor, insert it in the buffer map.
BufferMap* buffer_map = op_data->buffer_map; BufferMap* buffer_map = op_data_->buffer_map;
for (auto tensor_index : op_data->subgraph_inputs) { for (auto tensor_index : op_data_->subgraph_inputs) {
TfLiteTensor* tensor = &context->tensors[tensor_index]; TfLiteTensor* tensor = &context->tensors[tensor_index];
if (IsConstantTensor(tensor)) { if (IsConstantTensor(tensor)) {
if (!buffer_map->HasTensor(tensor_index)) { if (!buffer_map->HasTensor(tensor_index)) {
@ -469,12 +454,12 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// All output tensors are allocated by TensorFlow/Eager, so we // All output tensors are allocated by TensorFlow/Eager, so we
// mark them as kTfLiteDynamic. // mark them as kTfLiteDynamic.
for (auto tensor_index : op_data->subgraph_outputs) { for (auto tensor_index : op_data_->subgraph_outputs) {
SetTensorToDynamic(&context->tensors[tensor_index]); SetTensorToDynamic(&context->tensors[tensor_index]);
++tensor_ref_count[tensor_index]; ++tensor_ref_count[tensor_index];
} }
for (const auto& node_data : op_data->nodes) { for (const auto& node_data : op_data_->nodes) {
if (node_data->nodedef().op().empty()) { if (node_data->nodedef().op().empty()) {
context->ReportError(context, "Invalid NodeDef in Flex op '%s'", context->ReportError(context, "Invalid NodeDef in Flex op '%s'",
node_data->name().c_str()); node_data->name().c_str());
@ -490,7 +475,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// All tensors that are referenced exactly once are marked as "forwardable", // All tensors that are referenced exactly once are marked as "forwardable",
// meaning that we will allow TensorFlow to reuse its buffer as the output of // meaning that we will allow TensorFlow to reuse its buffer as the output of
// an op. // an op.
for (auto& node_data : op_data->nodes) { for (auto& node_data : op_data_->nodes) {
for (int i = 0; i < node_data->inputs().Size(); ++i) { for (int i = 0; i < node_data->inputs().Size(); ++i) {
bool f = (tensor_ref_count[node_data->inputs().TfLiteIndex(i)] == 1); bool f = (tensor_ref_count[node_data->inputs().TfLiteIndex(i)] == 1);
node_data->mutable_inputs()->SetForwardable(i, f); node_data->mutable_inputs()->SetForwardable(i, f);
@ -500,13 +485,12 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
return kTfLiteOk; return kTfLiteOk;
} }
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { TfLiteStatus DelegateKernel::Invoke(TfLiteContext* context, TfLiteNode* node) {
auto* op_data = reinterpret_cast<OpData*>(node->user_data); BufferMap* buffer_map = op_data_->buffer_map;
BufferMap* buffer_map = op_data->buffer_map;
// Insert a tensor in the buffer map for all inputs that are not constant. // Insert a tensor in the buffer map for all inputs that are not constant.
// Constants were handled in Prepare() already. // Constants were handled in Prepare() already.
for (auto tensor_index : op_data->subgraph_inputs) { for (auto tensor_index : op_data_->subgraph_inputs) {
TfLiteTensor* tensor = &context->tensors[tensor_index]; TfLiteTensor* tensor = &context->tensors[tensor_index];
if (!IsConstantTensor(tensor)) { if (!IsConstantTensor(tensor)) {
// If this tensor is part of an earlier TF subgraph we should not add it // If this tensor is part of an earlier TF subgraph we should not add it
@ -519,7 +503,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
} }
// Execute the TensorFlow Ops sequentially. // Execute the TensorFlow Ops sequentially.
for (auto& node_data : op_data->nodes) { for (auto& node_data : op_data_->nodes) {
TFLITE_SCOPED_DELEGATE_OPERATOR_PROFILE( TFLITE_SCOPED_DELEGATE_OPERATOR_PROFILE(
reinterpret_cast<Profiler*>(context->profiler), reinterpret_cast<Profiler*>(context->profiler),
node_data->name().c_str(), node_data->index()); node_data->name().c_str(), node_data->index());
@ -528,7 +512,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_OK(context, ConvertStatus(context, status)); TF_LITE_ENSURE_OK(context, ConvertStatus(context, status));
} }
for (auto tensor_index : op_data->subgraph_outputs) { for (auto tensor_index : op_data_->subgraph_outputs) {
if (!buffer_map->HasTensor(tensor_index)) { if (!buffer_map->HasTensor(tensor_index)) {
context->ReportError(context, "Cannot write to invalid tensor index %d", context->ReportError(context, "Cannot write to invalid tensor index %d",
tensor_index); tensor_index);
@ -546,21 +530,5 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
return kTfLiteOk; return kTfLiteOk;
} }
} // namespace kernel
TfLiteRegistration GetKernel() {
TfLiteRegistration registration{
&kernel::Init,
&kernel::Free,
&kernel::Prepare,
&kernel::Eval,
nullptr, // .profiling_string
kTfLiteBuiltinDelegate, // .builtin_code
"TfLiteFlexDelegate", // .custom_name
1, // .version
};
return registration;
}
} // namespace flex } // namespace flex
} // namespace tflite } // namespace tflite

View File

@ -15,18 +15,28 @@ limitations under the License.
#ifndef TENSORFLOW_LITE_DELEGATES_FLEX_KERNEL_H_ #ifndef TENSORFLOW_LITE_DELEGATES_FLEX_KERNEL_H_
#define TENSORFLOW_LITE_DELEGATES_FLEX_KERNEL_H_ #define TENSORFLOW_LITE_DELEGATES_FLEX_KERNEL_H_
#include <memory>
#include "tensorflow/lite/c/common.h" #include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/delegates/utils/simple_delegate.h"
namespace tflite { namespace tflite {
namespace flex { namespace flex {
// Return the registration object used to initialize and execute ops that will struct OpData;
// be delegated to TensorFlow's Eager runtime. This TF Lite op is created by class DelegateKernel : public SimpleDelegateKernelInterface {
// the flex delegate to handle execution of a supported subgraph. The usual public:
// flow is that the delegate informs the interpreter of supported nodes in a DelegateKernel();
// graph, and each supported subgraph is replaced with one instance of this ~DelegateKernel() override;
// kernel.
TfLiteRegistration GetKernel(); TfLiteStatus Init(TfLiteContext* context,
const TfLiteDelegateParams* params) override;
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) override;
TfLiteStatus Invoke(TfLiteContext* context, TfLiteNode* node) override;
private:
std::unique_ptr<OpData> op_data_;
};
} // namespace flex } // namespace flex
} // namespace tflite } // namespace tflite

View File

@ -12,38 +12,30 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
#include "tensorflow/lite/delegates/flex/kernel.h"
#include <gmock/gmock.h> #include <gmock/gmock.h>
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include "tensorflow/lite/delegates/flex/delegate.h"
#include "tensorflow/lite/delegates/flex/delegate_data.h" #include "tensorflow/lite/delegates/flex/delegate_data.h"
#include "tensorflow/lite/delegates/flex/test_util.h" #include "tensorflow/lite/delegates/flex/test_util.h"
namespace tflite { namespace tflite {
namespace flex { namespace flex {
namespace { namespace testing {
using ::testing::ContainsRegex; using ::testing::ContainsRegex;
using ::testing::ElementsAre; using ::testing::ElementsAre;
using ::testing::ElementsAreArray; using ::testing::ElementsAreArray;
TfLiteStatus GenericPrepare(TfLiteContext* context, TfLiteDelegate* delegate, // A testing flex delegate that supports every node regardless whether it's
const std::vector<int>& supported_nodes) { // actually supported or not. It's only for testing certain scenarios.
TfLiteIntArray* size_and_nodes = class TestFlexDelegate : public FlexDelegate {
ConvertVectorToTfLiteIntArray(supported_nodes); protected:
TF_LITE_ENSURE_STATUS(context->ReplaceNodeSubsetsWithDelegateKernels( bool IsNodeSupportedByDelegate(const TfLiteRegistration* registration,
context, flex::GetKernel(), size_and_nodes, delegate)); const TfLiteNode* node,
TfLiteIntArrayFree(size_and_nodes); TfLiteContext* context) const override {
return kTfLiteOk; return true;
} }
};
// There is no easy way to pass a parameter into the TfLiteDelegate's
// 'prepare' function, so we keep a global map for testing purposed.
// To avoid collisions use: GetPrepareFunction<__LINE__>().
std::map<int, std::vector<int>>* GetGlobalOpLists() {
static auto* op_list = new std::map<int, std::vector<int>>;
return op_list;
}
class KernelTest : public testing::FlexModelTest { class KernelTest : public testing::FlexModelTest {
public: public:
@ -51,51 +43,16 @@ class KernelTest : public testing::FlexModelTest {
static constexpr int kTwos = 2; // This is the index of a tensor of 2's. static constexpr int kTwos = 2; // This is the index of a tensor of 2's.
static constexpr int kMaxTensors = 30; static constexpr int kMaxTensors = 30;
static void SetUpTestSuite() { GetGlobalOpLists()->clear(); } KernelTest() { interpreter_.reset(new Interpreter(&error_reporter_)); }
KernelTest() { void ApplyFlexDelegate(std::unique_ptr<FlexDelegate> delegate = nullptr) {
CHECK(delegate_data_.Prepare(tensorflow::SessionOptions{}).ok()); auto flex_delegate = FlexDelegate::Create(std::move(delegate));
interpreter_.reset(new Interpreter(&error_reporter_)); auto* delegate_data =
reinterpret_cast<FlexDelegate*>(flex_delegate->data_)->mutable_data();
CHECK(delegate_data->Prepare(tensorflow::SessionOptions{}).ok());
CHECK(interpreter_->ModifyGraphWithDelegate(std::move(flex_delegate)) ==
kTfLiteOk);
} }
typedef TfLiteStatus (*PrepareFunction)(TfLiteContext* context,
TfLiteDelegate* delegate);
template <int KEY>
PrepareFunction GetPrepareFunction() {
GetGlobalOpLists()->insert({KEY, tf_ops_});
return [](TfLiteContext* context, TfLiteDelegate* delegate) {
return GenericPrepare(context, delegate, GetGlobalOpLists()->at(KEY));
};
}
template <typename T>
void ConfigureDelegate(T prepare_function) {
delegate_.data_ = &delegate_data_;
delegate_.flags = kTfLiteDelegateFlagsAllowDynamicTensors;
delegate_.FreeBufferHandle = nullptr;
delegate_.Prepare = prepare_function;
delegate_.CopyFromBufferHandle = [](TfLiteContext* context,
TfLiteDelegate* delegate,
TfLiteBufferHandle buffer_handle,
TfLiteTensor* output) {
auto* delegate_data = reinterpret_cast<DelegateData*>(delegate->data_);
auto* buffer_map = delegate_data->GetBufferMap(context);
if (!buffer_map->HasTensor(buffer_handle)) {
context->ReportError(context, "Tensor '%d' not found", buffer_handle);
return kTfLiteError;
}
tensorflow::StringPiece values =
buffer_map->GetTensor(buffer_handle).tensor_data();
memcpy(output->data.raw, values.data(), values.size());
return kTfLiteOk;
};
CHECK(interpreter_->ModifyGraphWithDelegate(&delegate_) == kTfLiteOk);
}
private:
DelegateData delegate_data_;
TfLiteDelegate delegate_;
}; };
TEST_F(KernelTest, FullGraph) { TEST_F(KernelTest, FullGraph) {
@ -108,10 +65,7 @@ TEST_F(KernelTest, FullGraph) {
AddTfOp(testing::kAdd, {2, 5}, {7}); AddTfOp(testing::kAdd, {2, 5}, {7});
AddTfOp(testing::kMul, {6, 7}, {8}); AddTfOp(testing::kMul, {6, 7}, {8});
// Apply Delegate. ApplyFlexDelegate();
ConfigureDelegate([](TfLiteContext* context, TfLiteDelegate* delegate) {
return GenericPrepare(context, delegate, {0, 1, 2, 3, 4});
});
// Define inputs. // Define inputs.
SetShape(0, {2, 2, 1}); SetShape(0, {2, 2, 1});
@ -140,9 +94,7 @@ TEST_F(KernelTest, BadTensorFlowOp) {
AddTensors(2, {0}, {1}, kTfLiteFloat32, {3}); AddTensors(2, {0}, {1}, kTfLiteFloat32, {3});
AddTfOp(testing::kNonExistent, {0}, {1}); AddTfOp(testing::kNonExistent, {0}, {1});
ConfigureDelegate([](TfLiteContext* context, TfLiteDelegate* delegate) { ApplyFlexDelegate(std::unique_ptr<FlexDelegate>(new TestFlexDelegate()));
return GenericPrepare(context, delegate, {0});
});
ASSERT_NE(interpreter_->AllocateTensors(), kTfLiteOk); ASSERT_NE(interpreter_->AllocateTensors(), kTfLiteOk);
ASSERT_THAT(error_reporter().error_messages(), ASSERT_THAT(error_reporter().error_messages(),
@ -153,9 +105,7 @@ TEST_F(KernelTest, BadNumberOfOutputs) {
AddTensors(3, {0}, {1, 2}, kTfLiteFloat32, {3}); AddTensors(3, {0}, {1, 2}, kTfLiteFloat32, {3});
AddTfOp(testing::kIdentity, {0}, {1, 2}); AddTfOp(testing::kIdentity, {0}, {1, 2});
ConfigureDelegate([](TfLiteContext* context, TfLiteDelegate* delegate) { ApplyFlexDelegate();
return GenericPrepare(context, delegate, {0});
});
SetShape(0, {2, 2, 1}); SetShape(0, {2, 2, 1});
SetValues(0, {1.1f, 2.2f, 3.3f, 4.4f}); SetValues(0, {1.1f, 2.2f, 3.3f, 4.4f});
@ -171,9 +121,7 @@ TEST_F(KernelTest, IncompatibleNodeDef) {
// Cast is a TF op, but we don't add the proper nodedef to it in AddTfOp. // Cast is a TF op, but we don't add the proper nodedef to it in AddTfOp.
AddTfOp(testing::kIncompatibleNodeDef, {0}, {1}); AddTfOp(testing::kIncompatibleNodeDef, {0}, {1});
ConfigureDelegate([](TfLiteContext* context, TfLiteDelegate* delegate) { ApplyFlexDelegate();
return GenericPrepare(context, delegate, {0});
});
SetShape(0, {2, 2, 1}); SetShape(0, {2, 2, 1});
SetValues(0, {1.1f, 2.2f, 3.3f, 4.4f}); SetValues(0, {1.1f, 2.2f, 3.3f, 4.4f});
@ -188,14 +136,14 @@ TEST_F(KernelTest, WrongSetOfNodes) {
AddTfOp(testing::kUnpack, {0}, {1, 2}); AddTfOp(testing::kUnpack, {0}, {1, 2});
AddTfLiteMulOp({1, 2}, {3}); AddTfLiteMulOp({1, 2}, {3});
// Specify that testing::kMul (#1) is supported when it actually isn't. // Specify that testing::kMul (#1) is supported when it actually isn't so that
ConfigureDelegate([](TfLiteContext* context, TfLiteDelegate* delegate) { // we choose to use the TestFlexDelegate that supports every node regardless
return GenericPrepare(context, delegate, {0, 1}); // whether it's actually supported or not.
}); ApplyFlexDelegate(std::unique_ptr<FlexDelegate>(new TestFlexDelegate()));
ASSERT_NE(interpreter_->AllocateTensors(), kTfLiteOk); ASSERT_NE(interpreter_->AllocateTensors(), kTfLiteOk);
ASSERT_THAT(error_reporter().error_messages(), ASSERT_THAT(error_reporter().error_messages(),
ContainsRegex("Invalid NodeDef in Flex op")); ContainsRegex("Cannot convert empty data into a valid NodeDef"));
} }
TEST_F(KernelTest, MixedGraph) { TEST_F(KernelTest, MixedGraph) {
@ -207,9 +155,7 @@ TEST_F(KernelTest, MixedGraph) {
AddTfOp(testing::kAdd, {2, 5}, {7}); AddTfOp(testing::kAdd, {2, 5}, {7});
AddTfLiteMulOp({6, 7}, {8}); AddTfLiteMulOp({6, 7}, {8});
ConfigureDelegate([](TfLiteContext* context, TfLiteDelegate* delegate) { ApplyFlexDelegate();
return GenericPrepare(context, delegate, {0, 1, 2, 3});
});
SetShape(0, {2, 2, 1}); SetShape(0, {2, 2, 1});
SetValues(0, {1.1f, 2.2f, 3.3f, 4.4f}); SetValues(0, {1.1f, 2.2f, 3.3f, 4.4f});
@ -251,14 +197,7 @@ TEST_F(KernelTest, SplitGraph) {
// The two branches added together: // The two branches added together:
AddTfOp(testing::kAdd, {9, 16}, {17}); // => 16 AddTfOp(testing::kAdd, {9, 16}, {17}); // => 16
ConfigureDelegate([](TfLiteContext* context, TfLiteDelegate* delegate) { ApplyFlexDelegate();
// All ops but #3 are TF ops, handled by the delegate. However, because #4
// depends on the non-TF op, two subgraphs are necessary:
// TF subgraph 1: 0, 1, 2, 6, 7, 8, 9
// TF Lite Op: 3
// TF subgraph 2: 4, 5, 10
return GenericPrepare(context, delegate, {0, 1, 2, 4, 5, 6, 7, 8, 9, 10});
});
SetShape(0, {2, 2, 2, 1}); SetShape(0, {2, 2, 2, 1});
SetValues(0, a); SetValues(0, a);
@ -291,9 +230,8 @@ class MultipleSubgraphsTest : public KernelTest {
public: public:
static constexpr int kInput = 0; static constexpr int kInput = 0;
void PrepareInterpreter(PrepareFunction prepare, void PrepareInterpreter(const std::vector<float>& input) {
const std::vector<float>& input) { ApplyFlexDelegate();
ConfigureDelegate(prepare);
SetShape(kOnes, {3}); SetShape(kOnes, {3});
SetValues(kOnes, {1.0f, 1.0f, 1.0f}); SetValues(kOnes, {1.0f, 1.0f, 1.0f});
@ -336,7 +274,7 @@ TEST_F(MultipleSubgraphsTest, ForwardabilityIsLocal) {
AddTfLiteMulOp({10, 7}, {12}); AddTfLiteMulOp({10, 7}, {12});
auto input = {3.0f, 4.0f, 5.0f}; auto input = {3.0f, 4.0f, 5.0f};
PrepareInterpreter(GetPrepareFunction<__LINE__>(), input); PrepareInterpreter(input);
ASSERT_TRUE(Invoke()); ASSERT_TRUE(Invoke());
ASSERT_THAT(GetValues(12), ElementsAreArray(Apply(input, [](float in) { ASSERT_THAT(GetValues(12), ElementsAreArray(Apply(input, [](float in) {
@ -371,7 +309,7 @@ TEST_F(MultipleSubgraphsTest, DoNotRemoveInputTensors) {
AddTfLiteMulOp({10, 7}, {12}); AddTfLiteMulOp({10, 7}, {12});
auto input = {3.0f, 4.0f, 5.0f}; auto input = {3.0f, 4.0f, 5.0f};
PrepareInterpreter(GetPrepareFunction<__LINE__>(), input); PrepareInterpreter(input);
ASSERT_TRUE(Invoke()); ASSERT_TRUE(Invoke());
ASSERT_THAT(GetValues(12), ElementsAreArray(Apply(input, [](float in) { ASSERT_THAT(GetValues(12), ElementsAreArray(Apply(input, [](float in) {
@ -405,7 +343,7 @@ TEST_F(MultipleSubgraphsTest, DoNotForwardInputTensors) {
AddTfLiteMulOp({10, 7}, {12}); AddTfLiteMulOp({10, 7}, {12});
auto input = {3.0f, 4.0f, 5.0f}; auto input = {3.0f, 4.0f, 5.0f};
PrepareInterpreter(GetPrepareFunction<__LINE__>(), input); PrepareInterpreter(input);
ASSERT_TRUE(Invoke()); ASSERT_TRUE(Invoke());
ASSERT_THAT(GetValues(12), ElementsAreArray(Apply(input, [](float in) { ASSERT_THAT(GetValues(12), ElementsAreArray(Apply(input, [](float in) {
@ -413,7 +351,7 @@ TEST_F(MultipleSubgraphsTest, DoNotForwardInputTensors) {
}))); })));
} }
} // namespace } // namespace testing
} // namespace flex } // namespace flex
} // namespace tflite } // namespace tflite

View File

@ -32,6 +32,7 @@ TfLiteRegistration GetDelegateKernelRegistration(
kernel_registration.profiling_string = nullptr; kernel_registration.profiling_string = nullptr;
kernel_registration.builtin_code = kTfLiteBuiltinDelegate; kernel_registration.builtin_code = kTfLiteBuiltinDelegate;
kernel_registration.custom_name = delegate->name(); kernel_registration.custom_name = delegate->name();
kernel_registration.version = 1;
kernel_registration.free = [](TfLiteContext* context, void* buffer) -> void { kernel_registration.free = [](TfLiteContext* context, void* buffer) -> void {
delete reinterpret_cast<SimpleDelegateKernelInterface*>(buffer); delete reinterpret_cast<SimpleDelegateKernelInterface*>(buffer);
}; };
@ -77,6 +78,7 @@ TfLiteStatus DelegatePrepare(TfLiteContext* context,
TfLiteDelegate* base_delegate) { TfLiteDelegate* base_delegate) {
auto* delegate = auto* delegate =
reinterpret_cast<SimpleDelegateInterface*>(base_delegate->data_); reinterpret_cast<SimpleDelegateInterface*>(base_delegate->data_);
TF_LITE_ENSURE_STATUS(delegate->Initialize(context));
delegates::IsNodeSupportedFn node_supported_fn = delegates::IsNodeSupportedFn node_supported_fn =
[=](TfLiteContext* context, TfLiteNode* node, [=](TfLiteContext* context, TfLiteNode* node,
TfLiteRegistration* registration, TfLiteRegistration* registration,
@ -125,5 +127,4 @@ void TfLiteDelegateFactory::DeleteSimpleDelegate(TfLiteDelegate* delegate) {
delete simple_delegate; delete simple_delegate;
delete delegate; delete delegate;
} }
} // namespace tflite } // namespace tflite

View File

@ -20,8 +20,12 @@ limitations under the License.
// this interface to build/prepare/invoke the delegated subgraph. // this interface to build/prepare/invoke the delegated subgraph.
// - SimpleDelegateInterface: // - SimpleDelegateInterface:
// This class wraps TFLiteDelegate and users need to implement the interface and // This class wraps TFLiteDelegate and users need to implement the interface and
// then Call GetFinalizedDelegate() to get TfLiteDelegate* that can be passed to // then call TfLiteDelegateFactory::CreateSimpleDelegate(...) to get
// ModifyGraphWithDelegate. // TfLiteDelegate* that can be passed to ModifyGraphWithDelegate and free it via
// TfLiteDelegateFactory::DeleteSimpleDelegate(...).
// or call TfLiteDelegateFactory::Create(...) to get a std::unique_ptr
// TfLiteDelegate that can also be passed to ModifyGraphWithDelegate, in which
// case TfLite interpereter takes the memory ownership of the delegate.
#ifndef TENSORFLOW_LITE_DELEGATES_UTILS_SIMPLE_DELEGATE_H_ #ifndef TENSORFLOW_LITE_DELEGATES_UTILS_SIMPLE_DELEGATE_H_
#define TENSORFLOW_LITE_DELEGATES_UTILS_SIMPLE_DELEGATE_H_ #define TENSORFLOW_LITE_DELEGATES_UTILS_SIMPLE_DELEGATE_H_
@ -31,6 +35,9 @@ limitations under the License.
namespace tflite { namespace tflite {
using TfLiteDelegateUniquePtr =
std::unique_ptr<TfLiteDelegate, void (*)(TfLiteDelegate*)>;
// Users should inherit from this class and implement the interface below. // Users should inherit from this class and implement the interface below.
// Each instance represents a single part of the graph (subgraph). // Each instance represents a single part of the graph (subgraph).
class SimpleDelegateKernelInterface { class SimpleDelegateKernelInterface {
@ -49,6 +56,7 @@ class SimpleDelegateKernelInterface {
// Actual subgraph inference should happen on this call. // Actual subgraph inference should happen on this call.
// Returns status, and signalling any errors. // Returns status, and signalling any errors.
// TODO(b/157882025): change this to Eval to be consistent w/ a TFLite kernel.
virtual TfLiteStatus Invoke(TfLiteContext* context, TfLiteNode* node) = 0; virtual TfLiteStatus Invoke(TfLiteContext* context, TfLiteNode* node) = 0;
}; };
@ -58,6 +66,7 @@ class SimpleDelegateKernelInterface {
// //
// Clients should implement the following methods: // Clients should implement the following methods:
// - IsNodeSupportedByDelegate // - IsNodeSupportedByDelegate
// - Initialize
// - name // - name
// - CreateDelegateKernelInterface // - CreateDelegateKernelInterface
class SimpleDelegateInterface { class SimpleDelegateInterface {
@ -71,8 +80,14 @@ class SimpleDelegateInterface {
const TfLiteNode* node, const TfLiteNode* node,
TfLiteContext* context) const = 0; TfLiteContext* context) const = 0;
// Initialize the delegate before finding and replacing TfLite nodes with
// delegate kernels, for example, retrieving some TFLite settings from
// 'context'.
virtual TfLiteStatus Initialize(TfLiteContext* context) = 0;
// Returns a name that identifies the delegate. // Returns a name that identifies the delegate.
// This name is used for debugging/logging/profiling. // This name is used for debugging/logging/profiling.
// TODO(b/157882025): change this to Name()
virtual const char* name() const = 0; virtual const char* name() const = 0;
// Returns instance of an object that implements the interface // Returns instance of an object that implements the interface
@ -84,13 +99,8 @@ class SimpleDelegateInterface {
CreateDelegateKernelInterface() = 0; CreateDelegateKernelInterface() = 0;
}; };
// Factory class that provides two static methods // Factory class that provides static methods to deal with SimpleDelegate
// CreateSimpleDelegate // creation and deletion.
// DeleteSimpleDelegate
// Which should be used to construct TfLiteDelegate from
// Simple Delegate and delete TfLiteDelegate and SimpleDelegate give
// tfLiteDelegate* created from 'CreateSimpleDelegate' method.
// Users should use these methods to Create and Destroy the delegate.
class TfLiteDelegateFactory { class TfLiteDelegateFactory {
public: public:
// Creates TfLiteDelegate from the provided SimpleDelegateInterface. // Creates TfLiteDelegate from the provided SimpleDelegateInterface.
@ -99,9 +109,17 @@ class TfLiteDelegateFactory {
std::unique_ptr<SimpleDelegateInterface> simple_delegate); std::unique_ptr<SimpleDelegateInterface> simple_delegate);
// Deletes 'delegate' the passed pointer must be the one returned // Deletes 'delegate' the passed pointer must be the one returned
// from GetFinalizedDelegate. // from CreateSimpleDelegate.
// This function will destruct the SimpleDelegate object too. // This function will destruct the SimpleDelegate object too.
static void DeleteSimpleDelegate(TfLiteDelegate* delegate); static void DeleteSimpleDelegate(TfLiteDelegate* delegate);
// A convenient function wrapping the above two functions and returning a
// std::unique_ptr type for auto memory management.
inline static TfLiteDelegateUniquePtr Create(
std::unique_ptr<SimpleDelegateInterface> simple_delegate) {
return TfLiteDelegateUniquePtr(
CreateSimpleDelegate(std::move(simple_delegate)), DeleteSimpleDelegate);
}
}; };
} // namespace tflite } // namespace tflite

View File

@ -72,7 +72,12 @@ class TestSimpleDelegate : public SimpleDelegateInterface {
return options_.allowed_builtin_code == registration->builtin_code; return options_.allowed_builtin_code == registration->builtin_code;
} }
const char* name() const override { return "TestSimpleDelegate"; } TfLiteStatus Initialize(TfLiteContext* context) override { return kTfLiteOk; }
const char* name() const override {
static constexpr char kName[] = "TestSimpleDelegate";
return kName;
}
std::unique_ptr<SimpleDelegateKernelInterface> CreateDelegateKernelInterface() std::unique_ptr<SimpleDelegateKernelInterface> CreateDelegateKernelInterface()
override { override {
@ -113,27 +118,24 @@ class TestDelegate : public ::testing::Test {
reg); reg);
} }
void TearDown() override { void TearDown() override { interpreter_.reset(); }
interpreter_.reset();
TfLiteDelegateFactory::DeleteSimpleDelegate(delegate_);
}
protected: protected:
std::unique_ptr<Interpreter> interpreter_; std::unique_ptr<Interpreter> interpreter_;
TfLiteDelegate* delegate_ = nullptr;
}; };
TEST_F(TestDelegate, BasicDelegate) { TEST_F(TestDelegate, BasicDelegate) {
TestSimpleDelegateOptions options; TestSimpleDelegateOptions options;
options.allowed_builtin_code = kTfLiteBuiltinAdd; options.allowed_builtin_code = kTfLiteBuiltinAdd;
delegate_ = TfLiteDelegateFactory::CreateSimpleDelegate( auto delegate = TfLiteDelegateFactory::Create(
std::make_unique<TestSimpleDelegate>(options)); std::make_unique<TestSimpleDelegate>(options));
interpreter_->ModifyGraphWithDelegate(delegate_); interpreter_->ModifyGraphWithDelegate(std::move(delegate));
ASSERT_EQ(interpreter_->execution_plan().size(), 1); ASSERT_EQ(interpreter_->execution_plan().size(), 1);
int node = interpreter_->execution_plan()[0]; int node = interpreter_->execution_plan()[0];
const auto* node_and_reg = interpreter_->node_and_registration(node); const auto* node_and_reg = interpreter_->node_and_registration(node);
EXPECT_EQ("TestSimpleDelegate", node_and_reg->second.custom_name); EXPECT_STREQ("TestSimpleDelegate", node_and_reg->second.custom_name);
EXPECT_EQ(1, node_and_reg->second.version);
const TfLiteDelegateParams* params = static_cast<const TfLiteDelegateParams*>( const TfLiteDelegateParams* params = static_cast<const TfLiteDelegateParams*>(
node_and_reg->first.builtin_data); node_and_reg->first.builtin_data);
@ -154,9 +156,9 @@ TEST_F(TestDelegate, BasicDelegate) {
TEST_F(TestDelegate, NoNodesToDelegate) { TEST_F(TestDelegate, NoNodesToDelegate) {
TestSimpleDelegateOptions options; TestSimpleDelegateOptions options;
options.allowed_builtin_code = kTfLiteBuiltinSub; options.allowed_builtin_code = kTfLiteBuiltinSub;
delegate_ = TfLiteDelegateFactory::CreateSimpleDelegate( auto delegate = TfLiteDelegateFactory::Create(
std::make_unique<TestSimpleDelegate>(options)); std::make_unique<TestSimpleDelegate>(options));
interpreter_->ModifyGraphWithDelegate(delegate_); interpreter_->ModifyGraphWithDelegate(std::move(delegate));
ASSERT_EQ(interpreter_->execution_plan().size(), 3); ASSERT_EQ(interpreter_->execution_plan().size(), 3);
} }
@ -165,19 +167,20 @@ TEST_F(TestDelegate, DelegateFailedPrepare) {
TestSimpleDelegateOptions options; TestSimpleDelegateOptions options;
options.allowed_builtin_code = kTfLiteBuiltinAdd; options.allowed_builtin_code = kTfLiteBuiltinAdd;
options.error_during_prepare = true; options.error_during_prepare = true;
delegate_ = TfLiteDelegateFactory::CreateSimpleDelegate( auto delegate = TfLiteDelegateFactory::Create(
std::make_unique<TestSimpleDelegate>(options)); std::make_unique<TestSimpleDelegate>(options));
ASSERT_EQ(kTfLiteDelegateError, ASSERT_EQ(kTfLiteDelegateError,
interpreter_->ModifyGraphWithDelegate(delegate_)); interpreter_->ModifyGraphWithDelegate(std::move(delegate)));
} }
TEST_F(TestDelegate, DelegateFailedInvoke) { TEST_F(TestDelegate, DelegateFailedInvoke) {
TestSimpleDelegateOptions options; TestSimpleDelegateOptions options;
options.allowed_builtin_code = kTfLiteBuiltinAdd; options.allowed_builtin_code = kTfLiteBuiltinAdd;
options.error_during_invoke = true; options.error_during_invoke = true;
delegate_ = TfLiteDelegateFactory::CreateSimpleDelegate( auto delegate = TfLiteDelegateFactory::Create(
std::make_unique<TestSimpleDelegate>(options)); std::make_unique<TestSimpleDelegate>(options));
ASSERT_EQ(kTfLiteOk, interpreter_->ModifyGraphWithDelegate(delegate_)); ASSERT_EQ(kTfLiteOk,
interpreter_->ModifyGraphWithDelegate(std::move(delegate)));
ASSERT_EQ(kTfLiteError, interpreter_->Invoke()); ASSERT_EQ(kTfLiteError, interpreter_->Invoke());
} }
@ -185,10 +188,10 @@ TEST_F(TestDelegate, DelegateFailedInit) {
TestSimpleDelegateOptions options; TestSimpleDelegateOptions options;
options.allowed_builtin_code = kTfLiteBuiltinAdd; options.allowed_builtin_code = kTfLiteBuiltinAdd;
options.error_during_init = true; options.error_during_init = true;
delegate_ = TfLiteDelegateFactory::CreateSimpleDelegate( auto delegate = TfLiteDelegateFactory::Create(
std::make_unique<TestSimpleDelegate>(options)); std::make_unique<TestSimpleDelegate>(options));
ASSERT_EQ(kTfLiteDelegateError, ASSERT_EQ(kTfLiteDelegateError,
interpreter_->ModifyGraphWithDelegate(delegate_)); interpreter_->ModifyGraphWithDelegate(std::move(delegate)));
} }
} // namespace } // namespace
} // namespace tflite } // namespace tflite

View File

@ -337,10 +337,7 @@ TfLiteDriver::TfLiteDriver(DelegateType delegate_type, bool reference_kernel)
break; break;
case DelegateType::kFlex: case DelegateType::kFlex:
#if !defined(__APPLE__) #if !defined(__APPLE__)
delegate_ = Interpreter::TfLiteDelegatePtr( delegate_ = FlexDelegate::Create();
FlexDelegate::Create().release(), [](TfLiteDelegate* delegate) {
delete static_cast<tflite::FlexDelegate*>(delegate);
});
#endif #endif
break; break;
} }

View File

@ -744,8 +744,9 @@ typedef struct TfLiteDelegate {
struct TfLiteDelegate* delegate); struct TfLiteDelegate* delegate);
// Copy the data from delegate buffer handle into raw memory of the given // Copy the data from delegate buffer handle into raw memory of the given
// 'tensor'. This cannot be null. The delegate is allowed to allocate the raw // 'tensor'. Note that the delegate is allowed to allocate the raw bytes as
// bytes as long as it follows the rules for kTfLiteDynamic tensors. // long as it follows the rules for kTfLiteDynamic tensors, in which case this
// cannot be null.
TfLiteStatus (*CopyFromBufferHandle)(TfLiteContext* context, TfLiteStatus (*CopyFromBufferHandle)(TfLiteContext* context,
struct TfLiteDelegate* delegate, struct TfLiteDelegate* delegate,
TfLiteBufferHandle buffer_handle, TfLiteBufferHandle buffer_handle,