1. Adapt flex delegate to the newly created SimpleDelegate APIs.

2. Change SimpleDelegate APIs accordingly as needed to adapt the flex delegate.
a. Create a `version` virtual function.
b. Add a new function acting as an initialization for TfLiteDelegate Prepare.
c. Return a std::unique_ptr<TfLiteDelegate> from the factory method of TfLiteDelegateFactory for less error-prone memory management.

PiperOrigin-RevId: 314225150
Change-Id: I6bf7ab48f1e72f390b49dd5c6ac7ecd11d29327a
This commit is contained in:
Chao Mei 2020-06-01 16:23:43 -07:00 committed by TensorFlower Gardener
parent 24580ebdd9
commit edeae9fb69
15 changed files with 254 additions and 327 deletions

View File

@ -744,8 +744,9 @@ typedef struct TfLiteDelegate {
struct TfLiteDelegate* delegate);
// Copy the data from delegate buffer handle into raw memory of the given
// 'tensor'. This cannot be null. The delegate is allowed to allocate the raw
// bytes as long as it follows the rules for kTfLiteDynamic tensors.
// 'tensor'. Note that the delegate is allowed to allocate the raw bytes as
// long as it follows the rules for kTfLiteDynamic tensors, in which case this
// cannot be null.
TfLiteStatus (*CopyFromBufferHandle)(TfLiteContext* context,
struct TfLiteDelegate* delegate,
TfLiteBufferHandle buffer_handle,

View File

@ -61,6 +61,7 @@ cc_library(
deps = [
":delegate_data",
":delegate_only_runtime",
"//tensorflow/lite/delegates/utils:simple_delegate",
] + select({
"//tensorflow:android": [
"//tensorflow/core:portable_tensorflow_lib",
@ -82,6 +83,8 @@ cc_library(
name = "delegate_only_runtime",
srcs = [
"delegate.cc",
"kernel.cc",
"kernel.h",
],
hdrs = [
"delegate.h",
@ -90,14 +93,18 @@ cc_library(
deps = [
":buffer_map",
":delegate_data",
":kernel",
":util",
"@flatbuffers",
"@com_google_absl//absl/strings:strings",
"//tensorflow/lite/core/api",
"//tensorflow/lite/c:common",
"//tensorflow/lite:kernel_api",
"//tensorflow/lite:minimal_logging",
"//tensorflow/lite:string",
"//tensorflow/lite:string_util",
"//tensorflow/lite:util",
"//tensorflow/lite/delegates/utils:simple_delegate",
"//tensorflow/lite/kernels:kernel_util",
] + select({
"//tensorflow:android": [
"//tensorflow/core:portable_tensorflow_lib_lite",
@ -106,7 +113,12 @@ cc_library(
"//tensorflow/core:portable_tensorflow_lib_lite",
],
"//conditions:default": [
"//tensorflow/core/common_runtime/eager:context",
"//tensorflow/core/common_runtime/eager:execute",
"//tensorflow/core/common_runtime/eager:tensor_handle",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:framework",
],
}),
alwayslink = 1,
@ -163,40 +175,6 @@ tf_cc_test(
],
)
cc_library(
name = "kernel",
srcs = ["kernel.cc"],
hdrs = ["kernel.h"],
deps = [
":delegate_data",
":util",
"@flatbuffers",
"//tensorflow/lite/core/api",
"//tensorflow/lite/c:common",
"//tensorflow/lite:kernel_api",
"//tensorflow/lite:string",
"//tensorflow/lite/kernels:kernel_util",
] + select({
# TODO(b/111881878): The android_tensorflow_lib target pulls in the full
# set of core TensorFlow kernels. We may want to revisit this dependency
# to allow selective registration via build targets.
"//tensorflow:android": [
"//tensorflow/core:portable_tensorflow_lib_lite",
],
"//tensorflow:ios": [
"//tensorflow/core:portable_tensorflow_lib_lite",
],
"//conditions:default": [
"//tensorflow/core/common_runtime/eager:context",
"//tensorflow/core/common_runtime/eager:execute",
"//tensorflow/core/common_runtime/eager:tensor_handle",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:framework",
],
}),
)
tf_cc_test(
name = "kernel_test",
size = "small",
@ -204,20 +182,10 @@ tf_cc_test(
tags = ["no_gpu"], # GPU + flex is not officially supported.
deps = [
":delegate_data",
":kernel",
":delegate_only_runtime",
":test_util",
"@com_google_googletest//:gtest",
] + select({
"//tensorflow:android": [
"//tensorflow/core:portable_tensorflow_lib",
],
"//tensorflow:ios": [
"//tensorflow/core:portable_tensorflow_lib",
],
"//conditions:default": [
"//tensorflow/core:tensorflow",
],
}),
],
)
cc_library(

View File

@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/lite/delegates/flex/delegate.h"
#include <memory>
#include <vector>
#include "absl/strings/str_cat.h"
@ -27,10 +28,32 @@ limitations under the License.
#include "tensorflow/lite/util.h"
namespace tflite {
namespace flex {
namespace delegate {
TfLiteStatus Prepare(TfLiteContext* context, TfLiteDelegate* delegate) {
// Corresponding weak declaration found in lite/interpreter_builder.cc.
TfLiteDelegateUniquePtr AcquireFlexDelegate() {
return tflite::FlexDelegate::Create();
}
TfLiteDelegateUniquePtr FlexDelegate::Create(
std::unique_ptr<FlexDelegate> base_delegate) {
TFLITE_LOG_PROD_ONCE(TFLITE_LOG_INFO,
"Created TensorFlow Lite delegate for select TF ops.");
if (base_delegate == nullptr) {
base_delegate.reset(new FlexDelegate());
}
auto flex_delegate = TfLiteDelegateFactory::Create(std::move(base_delegate));
flex_delegate->CopyFromBufferHandle =
[](TfLiteContext* context, TfLiteDelegate* delegate,
TfLiteBufferHandle buffer_handle,
TfLiteTensor* tensor) -> TfLiteStatus {
return reinterpret_cast<FlexDelegate*>(delegate->data_)
->CopyFromBufferHandle(context, buffer_handle, tensor);
};
flex_delegate->flags |= kTfLiteDelegateFlagsAllowDynamicTensors;
return flex_delegate;
}
TfLiteStatus FlexDelegate::Initialize(TfLiteContext* context) {
// If the TensorFlow Lite thread count is explicitly configured, use it,
// otherwise rely on the default TensorFlow threading behavior.
tensorflow::SessionOptions session_options;
@ -39,47 +62,37 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteDelegate* delegate) {
context->recommended_num_threads);
}
auto status = reinterpret_cast<DelegateData*>(delegate->data_)
->Prepare(session_options);
auto status = delegate_data_.Prepare(session_options);
if (!status.ok()) {
context->ReportError(context, "Failed to initialize TensorFlow context: %s",
status.error_message().c_str());
return kTfLiteError;
}
// Get the nodes in the current execution plan. Interpreter owns this array.
TfLiteIntArray* plan;
TF_LITE_ENSURE_STATUS(context->GetExecutionPlan(context, &plan));
// Add all custom ops starting with "Flex" to list of supported nodes.
std::vector<int> supported_nodes;
for (int node_index : TfLiteIntArrayView(plan)) {
TfLiteNode* node;
TfLiteRegistration* registration;
TF_LITE_ENSURE_STATUS(context->GetNodeAndRegistration(
context, node_index, &node, &registration));
if (IsFlexOp(registration->custom_name)) {
supported_nodes.push_back(node_index);
}
}
// Request TFLite to partition the graph and make kernels for each independent
// node sub set.
TfLiteIntArray* size_and_nodes =
ConvertVectorToTfLiteIntArray(supported_nodes);
context->ReplaceNodeSubsetsWithDelegateKernels(context, GetKernel(),
size_and_nodes, delegate);
TfLiteIntArrayFree(size_and_nodes);
return kTfLiteOk;
}
TfLiteStatus CopyFromBufferHandle(TfLiteContext* context,
TfLiteDelegate* delegate,
TfLiteBufferHandle buffer_handle,
TfLiteTensor* output) {
BufferMap* buffer_map =
reinterpret_cast<DelegateData*>(delegate->data_)->GetBufferMap(context);
const char* FlexDelegate::name() const {
static constexpr char kName[] = "TfLiteFlexDelegate";
return kName;
}
bool FlexDelegate::IsNodeSupportedByDelegate(
const TfLiteRegistration* registration, const TfLiteNode* node,
TfLiteContext* context) const {
return IsFlexOp(registration->custom_name);
}
std::unique_ptr<SimpleDelegateKernelInterface>
FlexDelegate::CreateDelegateKernelInterface() {
return std::unique_ptr<SimpleDelegateKernelInterface>(
new tflite::flex::DelegateKernel());
}
TfLiteStatus FlexDelegate::CopyFromBufferHandle(
TfLiteContext* context, TfLiteBufferHandle buffer_handle,
TfLiteTensor* output) {
flex::BufferMap* buffer_map = delegate_data_.GetBufferMap(context);
if (!buffer_map->HasTensor(buffer_handle)) {
context->ReportError(context, "Invalid tensor index %d.", buffer_handle);
@ -122,31 +135,4 @@ TfLiteStatus CopyFromBufferHandle(TfLiteContext* context,
return kTfLiteOk;
}
} // namespace delegate
} // namespace flex
// Corresponding weak declaration found in lite/model.cc.
std::unique_ptr<TfLiteDelegate, void (*)(TfLiteDelegate*)>
AcquireFlexDelegate() {
return std::unique_ptr<TfLiteDelegate, void (*)(TfLiteDelegate*)>(
tflite::FlexDelegate::Create().release(), [](TfLiteDelegate* delegate) {
delete reinterpret_cast<tflite::FlexDelegate*>(delegate);
});
}
std::unique_ptr<FlexDelegate> FlexDelegate::Create() {
TFLITE_LOG_PROD_ONCE(TFLITE_LOG_INFO,
"Created TensorFlow Lite delegate for select TF ops.");
return std::unique_ptr<FlexDelegate>(new FlexDelegate());
}
FlexDelegate::FlexDelegate() : TfLiteDelegate(TfLiteDelegateCreate()) {
data_ = &delegate_data_;
Prepare = &flex::delegate::Prepare;
CopyFromBufferHandle = &flex::delegate::CopyFromBufferHandle;
flags = kTfLiteDelegateFlagsAllowDynamicTensors;
}
FlexDelegate::~FlexDelegate() {}
} // namespace tflite

View File

@ -17,9 +17,16 @@ limitations under the License.
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/delegates/flex/delegate_data.h"
#include "tensorflow/lite/delegates/utils/simple_delegate.h"
namespace tflite {
namespace flex {
namespace testing {
class KernelTest;
} // namespace testing
} // namespace flex
// WARNING: This is an experimental interface that is subject to change.
// Delegate that can be used to extract parts of a graph that are designed to be
// executed by TensorFlow's runtime via Eager.
@ -33,22 +40,49 @@ namespace tflite {
// ... build interpreter ...
//
// if (delegate) {
// interpreter->ModifyGraphWithDelegate(delegate.get());
// interpreter->ModifyGraphWithDelegate(std::move(delegate));
// }
// ... run inference ...
// ... destroy interpreter ...
// ... destroy delegate ...
class FlexDelegate : public TfLiteDelegate {
class FlexDelegate : public SimpleDelegateInterface {
public:
friend class flex::testing::KernelTest;
// Creates a delegate that supports TF ops.
//
// If the underlying TF Flex context creation fails, returns null.
static std::unique_ptr<FlexDelegate> Create();
static TfLiteDelegateUniquePtr Create() {
return Create(/*base_delegate*/ nullptr);
}
~FlexDelegate();
~FlexDelegate() override {}
private:
FlexDelegate();
flex::DelegateData* mutable_data() { return &delegate_data_; }
protected:
// We sometimes have to create certain stub data to test FlexDelegate. To
// achieve this, we will make a testing flex delegate class that inherits from
// FlexDelegate to override certain things for stub data creation. Therefore,
this function accepts a FlexDelegate instance to initialize it properly to
create a testing flex delegate in some cases, and it is only used in
// testing.
static TfLiteDelegateUniquePtr Create(
std::unique_ptr<FlexDelegate> base_delegate);
FlexDelegate() {}
const char* name() const override;
bool IsNodeSupportedByDelegate(const TfLiteRegistration* registration,
const TfLiteNode* node,
TfLiteContext* context) const override;
TfLiteStatus Initialize(TfLiteContext* context) override;
std::unique_ptr<SimpleDelegateKernelInterface> CreateDelegateKernelInterface()
override;
TfLiteStatus CopyFromBufferHandle(TfLiteContext* context,
TfLiteBufferHandle buffer_handle,
TfLiteTensor* output);
flex::DelegateData delegate_data_;
};

View File

@ -26,8 +26,7 @@ using ::testing::ElementsAre;
class DelegateTest : public testing::FlexModelTest {
public:
DelegateTest() {
delegate_ = FlexDelegate::Create();
DelegateTest() : delegate_(FlexDelegate::Create()) {
interpreter_.reset(new Interpreter(&error_reporter_));
}
@ -44,7 +43,7 @@ class DelegateTest : public testing::FlexModelTest {
}
private:
std::unique_ptr<FlexDelegate> delegate_;
std::unique_ptr<TfLiteDelegate, void (*)(TfLiteDelegate*)> delegate_;
};
TEST_F(DelegateTest, FullGraph) {

View File

@ -18,6 +18,7 @@ cc_library(
],
deps = [
"//tensorflow/lite/delegates/flex:delegate",
"//tensorflow/lite/delegates/utils:simple_delegate",
"//tensorflow/lite/java/jni",
"//tensorflow/lite/testing:init_tensorflow",
],

View File

@ -16,6 +16,7 @@ limitations under the License.
#include <jni.h>
#include "tensorflow/lite/delegates/flex/delegate.h"
#include "tensorflow/lite/delegates/utils/simple_delegate.h"
#include "tensorflow/lite/testing/init_tensorflow.h"
#ifdef __cplusplus
@ -37,7 +38,8 @@ Java_org_tensorflow_lite_flex_FlexDelegate_nativeCreateDelegate(JNIEnv* env,
JNIEXPORT void JNICALL
Java_org_tensorflow_lite_flex_FlexDelegate_nativeDeleteDelegate(
JNIEnv* env, jclass clazz, jlong delegate) {
delete reinterpret_cast<tflite::FlexDelegate*>(delegate);
tflite::TfLiteDelegateFactory::DeleteSimpleDelegate(
reinterpret_cast<struct TfLiteDelegate*>(delegate));
}
#ifdef __cplusplus

View File

@ -25,6 +25,7 @@ limitations under the License.
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/context_util.h"
#include "tensorflow/lite/core/api/profiler.h"
#include "tensorflow/lite/delegates/flex/delegate.h"
#include "tensorflow/lite/delegates/flex/delegate_data.h"
#include "tensorflow/lite/delegates/flex/util.h"
#include "tensorflow/lite/kernels/kernel_util.h"
@ -49,7 +50,6 @@ limitations under the License.
namespace tflite {
namespace flex {
namespace kernel {
struct OpNode;
@ -357,33 +357,29 @@ struct OpData {
std::vector<int> subgraph_outputs;
};
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
auto* op_data = new OpData;
DelegateKernel::DelegateKernel() : op_data_(new OpData) {}
DelegateKernel::~DelegateKernel() {}
const TfLiteDelegateParams* params =
reinterpret_cast<const TfLiteDelegateParams*>(buffer);
CHECK(params);
CHECK(params->delegate);
CHECK(params->delegate->data_);
op_data->eager_context =
reinterpret_cast<DelegateData*>(params->delegate->data_)
->GetEagerContext();
op_data->buffer_map = reinterpret_cast<DelegateData*>(params->delegate->data_)
->GetBufferMap(context);
TfLiteStatus DelegateKernel::Init(TfLiteContext* context,
const TfLiteDelegateParams* params) {
auto* flex_delegate_data =
reinterpret_cast<FlexDelegate*>(params->delegate->data_)->mutable_data();
op_data_->eager_context = flex_delegate_data->GetEagerContext();
op_data_->buffer_map = flex_delegate_data->GetBufferMap(context);
CHECK(params->output_tensors);
std::set<int> output_set;
for (auto tensor_index : TfLiteIntArrayView(params->output_tensors)) {
op_data->subgraph_outputs.push_back(tensor_index);
op_data_->subgraph_outputs.push_back(tensor_index);
output_set.insert(tensor_index);
}
CHECK(params->input_tensors);
for (auto tensor_index : TfLiteIntArrayView(params->input_tensors)) {
op_data->subgraph_inputs.push_back(tensor_index);
op_data_->subgraph_inputs.push_back(tensor_index);
}
op_data->nodes.reserve(params->nodes_to_replace->size);
op_data_->nodes.reserve(params->nodes_to_replace->size);
CHECK(params->nodes_to_replace);
tensorflow::Status status;
@ -392,8 +388,8 @@ void* Init(TfLiteContext* context, const char* buffer, size_t length) {
TfLiteRegistration* reg;
context->GetNodeAndRegistration(context, node_index, &node, &reg);
op_data->nodes.emplace_back(new OpNode(node->inputs, node->outputs));
OpNode& node_data = *op_data->nodes.back();
op_data_->nodes.emplace_back(new OpNode(node->inputs, node->outputs));
OpNode& node_data = *op_data_->nodes.back();
node_data.set_index(node_index);
node_data.set_name("");
@ -401,16 +397,11 @@ void* Init(TfLiteContext* context, const char* buffer, size_t length) {
status = node_data.InitializeNodeDef(node->custom_initial_data,
node->custom_initial_data_size);
if (!status.ok()) break;
status = node_data.BuildEagerOp(op_data->eager_context);
status = node_data.BuildEagerOp(op_data_->eager_context);
if (!status.ok()) break;
}
if (ConvertStatus(context, status) != kTfLiteOk) {
// We can't return an error from this function but ConvertStatus will
// report them and we will stop processing in Prepare() if anything went
// wrong.
return op_data;
}
TF_LITE_ENSURE_STATUS(ConvertStatus(context, status));
// Given a TfLite tensor index, return the OpNode that produces it,
// along with it index into that OpNodes list of outputs.
@ -418,7 +409,7 @@ void* Init(TfLiteContext* context, const char* buffer, size_t length) {
// Find out how each tensor is produced. This does not account for
// tensors that are not produce by eager ops.
for (auto& node_data : op_data->nodes) {
for (auto& node_data : op_data_->nodes) {
node_data->mutable_outputs()->InitializeGraphOutputs(output_set);
for (int i = 0; i < node_data->outputs().Size(); ++i) {
int output_index = node_data->outputs().TfLiteIndex(i);
@ -428,21 +419,15 @@ void* Init(TfLiteContext* context, const char* buffer, size_t length) {
// For each node, resolve the inputs, so we can keep pointers to the nodes
// that produces them.
for (auto& node_data : op_data->nodes) {
for (auto& node_data : op_data_->nodes) {
node_data->mutable_inputs()->InitializeTensorSources(tflite_tensor_sources);
}
return op_data;
return kTfLiteOk;
}
void Free(TfLiteContext* context, void* buffer) {
delete reinterpret_cast<OpData*>(buffer);
}
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
const auto* op_data = reinterpret_cast<OpData*>(node->user_data);
TfLiteStatus DelegateKernel::Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_MSG(
context, op_data->eager_context != nullptr,
context, op_data_->eager_context != nullptr,
"Failed to initialize eager context. This often happens when a CPU "
"device has not been registered, presumably because some symbols from "
"tensorflow/core:core_cpu_impl were not linked into the binary.");
@ -452,8 +437,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
std::map<int, int> tensor_ref_count;
// Whenever we find a constant tensor, insert it in the buffer map.
BufferMap* buffer_map = op_data->buffer_map;
for (auto tensor_index : op_data->subgraph_inputs) {
BufferMap* buffer_map = op_data_->buffer_map;
for (auto tensor_index : op_data_->subgraph_inputs) {
TfLiteTensor* tensor = &context->tensors[tensor_index];
if (IsConstantTensor(tensor)) {
if (!buffer_map->HasTensor(tensor_index)) {
@ -469,12 +454,12 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// All output tensors are allocated by TensorFlow/Eager, so we
// mark them as kTfLiteDynamic.
for (auto tensor_index : op_data->subgraph_outputs) {
for (auto tensor_index : op_data_->subgraph_outputs) {
SetTensorToDynamic(&context->tensors[tensor_index]);
++tensor_ref_count[tensor_index];
}
for (const auto& node_data : op_data->nodes) {
for (const auto& node_data : op_data_->nodes) {
if (node_data->nodedef().op().empty()) {
context->ReportError(context, "Invalid NodeDef in Flex op '%s'",
node_data->name().c_str());
@ -490,7 +475,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// All tensors that are referenced exactly once are marked as "forwardable",
// meaning that we will allow TensorFlow to reuse its buffer as the output of
// an op.
for (auto& node_data : op_data->nodes) {
for (auto& node_data : op_data_->nodes) {
for (int i = 0; i < node_data->inputs().Size(); ++i) {
bool f = (tensor_ref_count[node_data->inputs().TfLiteIndex(i)] == 1);
node_data->mutable_inputs()->SetForwardable(i, f);
@ -500,13 +485,12 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
return kTfLiteOk;
}
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
auto* op_data = reinterpret_cast<OpData*>(node->user_data);
BufferMap* buffer_map = op_data->buffer_map;
TfLiteStatus DelegateKernel::Invoke(TfLiteContext* context, TfLiteNode* node) {
BufferMap* buffer_map = op_data_->buffer_map;
// Insert a tensor in the buffer map for all inputs that are not constant.
// Constants were handled in Prepare() already.
for (auto tensor_index : op_data->subgraph_inputs) {
for (auto tensor_index : op_data_->subgraph_inputs) {
TfLiteTensor* tensor = &context->tensors[tensor_index];
if (!IsConstantTensor(tensor)) {
// If this tensor is part of an earlier TF subgraph we should not add it
@ -519,7 +503,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
}
// Execute the TensorFlow Ops sequentially.
for (auto& node_data : op_data->nodes) {
for (auto& node_data : op_data_->nodes) {
TFLITE_SCOPED_DELEGATE_OPERATOR_PROFILE(
reinterpret_cast<Profiler*>(context->profiler),
node_data->name().c_str(), node_data->index());
@ -528,7 +512,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_OK(context, ConvertStatus(context, status));
}
for (auto tensor_index : op_data->subgraph_outputs) {
for (auto tensor_index : op_data_->subgraph_outputs) {
if (!buffer_map->HasTensor(tensor_index)) {
context->ReportError(context, "Cannot write to invalid tensor index %d",
tensor_index);
@ -546,21 +530,5 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
return kTfLiteOk;
}
} // namespace kernel
TfLiteRegistration GetKernel() {
TfLiteRegistration registration{
&kernel::Init,
&kernel::Free,
&kernel::Prepare,
&kernel::Eval,
nullptr, // .profiling_string
kTfLiteBuiltinDelegate, // .builtin_code
"TfLiteFlexDelegate", // .custom_name
1, // .version
};
return registration;
}
} // namespace flex
} // namespace tflite

View File

@ -15,18 +15,28 @@ limitations under the License.
#ifndef TENSORFLOW_LITE_DELEGATES_FLEX_KERNEL_H_
#define TENSORFLOW_LITE_DELEGATES_FLEX_KERNEL_H_
#include <memory>
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/delegates/utils/simple_delegate.h"
namespace tflite {
namespace flex {
// Return the registration object used to initialize and execute ops that will
// be delegated to TensorFlow's Eager runtime. This TF Lite op is created by
// the flex delegate to handle execution of a supported subgraph. The usual
// flow is that the delegate informs the interpreter of supported nodes in a
// graph, and each supported subgraph is replaced with one instance of this
// kernel.
TfLiteRegistration GetKernel();
struct OpData;
class DelegateKernel : public SimpleDelegateKernelInterface {
public:
DelegateKernel();
~DelegateKernel() override;
TfLiteStatus Init(TfLiteContext* context,
const TfLiteDelegateParams* params) override;
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) override;
TfLiteStatus Invoke(TfLiteContext* context, TfLiteNode* node) override;
private:
std::unique_ptr<OpData> op_data_;
};
} // namespace flex
} // namespace tflite

View File

@ -12,38 +12,30 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/delegates/flex/kernel.h"
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "tensorflow/lite/delegates/flex/delegate.h"
#include "tensorflow/lite/delegates/flex/delegate_data.h"
#include "tensorflow/lite/delegates/flex/test_util.h"
namespace tflite {
namespace flex {
namespace {
namespace testing {
using ::testing::ContainsRegex;
using ::testing::ElementsAre;
using ::testing::ElementsAreArray;
TfLiteStatus GenericPrepare(TfLiteContext* context, TfLiteDelegate* delegate,
const std::vector<int>& supported_nodes) {
TfLiteIntArray* size_and_nodes =
ConvertVectorToTfLiteIntArray(supported_nodes);
TF_LITE_ENSURE_STATUS(context->ReplaceNodeSubsetsWithDelegateKernels(
context, flex::GetKernel(), size_and_nodes, delegate));
TfLiteIntArrayFree(size_and_nodes);
return kTfLiteOk;
}
// There is no easy way to pass a parameter into the TfLiteDelegate's
// 'prepare' function, so we keep a global map for testing purposed.
// To avoid collisions use: GetPrepareFunction<__LINE__>().
std::map<int, std::vector<int>>* GetGlobalOpLists() {
static auto* op_list = new std::map<int, std::vector<int>>;
return op_list;
}
// A testing flex delegate that supports every node regardless whether it's
// actually supported or not. It's only for testing certain scenarios.
class TestFlexDelegate : public FlexDelegate {
protected:
bool IsNodeSupportedByDelegate(const TfLiteRegistration* registration,
const TfLiteNode* node,
TfLiteContext* context) const override {
return true;
}
};
class KernelTest : public testing::FlexModelTest {
public:
@ -51,51 +43,16 @@ class KernelTest : public testing::FlexModelTest {
static constexpr int kTwos = 2; // This is the index of a tensor of 2's.
static constexpr int kMaxTensors = 30;
static void SetUpTestSuite() { GetGlobalOpLists()->clear(); }
KernelTest() { interpreter_.reset(new Interpreter(&error_reporter_)); }
KernelTest() {
CHECK(delegate_data_.Prepare(tensorflow::SessionOptions{}).ok());
interpreter_.reset(new Interpreter(&error_reporter_));
void ApplyFlexDelegate(std::unique_ptr<FlexDelegate> delegate = nullptr) {
auto flex_delegate = FlexDelegate::Create(std::move(delegate));
auto* delegate_data =
reinterpret_cast<FlexDelegate*>(flex_delegate->data_)->mutable_data();
CHECK(delegate_data->Prepare(tensorflow::SessionOptions{}).ok());
CHECK(interpreter_->ModifyGraphWithDelegate(std::move(flex_delegate)) ==
kTfLiteOk);
}
typedef TfLiteStatus (*PrepareFunction)(TfLiteContext* context,
TfLiteDelegate* delegate);
template <int KEY>
PrepareFunction GetPrepareFunction() {
GetGlobalOpLists()->insert({KEY, tf_ops_});
return [](TfLiteContext* context, TfLiteDelegate* delegate) {
return GenericPrepare(context, delegate, GetGlobalOpLists()->at(KEY));
};
}
template <typename T>
void ConfigureDelegate(T prepare_function) {
delegate_.data_ = &delegate_data_;
delegate_.flags = kTfLiteDelegateFlagsAllowDynamicTensors;
delegate_.FreeBufferHandle = nullptr;
delegate_.Prepare = prepare_function;
delegate_.CopyFromBufferHandle = [](TfLiteContext* context,
TfLiteDelegate* delegate,
TfLiteBufferHandle buffer_handle,
TfLiteTensor* output) {
auto* delegate_data = reinterpret_cast<DelegateData*>(delegate->data_);
auto* buffer_map = delegate_data->GetBufferMap(context);
if (!buffer_map->HasTensor(buffer_handle)) {
context->ReportError(context, "Tensor '%d' not found", buffer_handle);
return kTfLiteError;
}
tensorflow::StringPiece values =
buffer_map->GetTensor(buffer_handle).tensor_data();
memcpy(output->data.raw, values.data(), values.size());
return kTfLiteOk;
};
CHECK(interpreter_->ModifyGraphWithDelegate(&delegate_) == kTfLiteOk);
}
private:
DelegateData delegate_data_;
TfLiteDelegate delegate_;
};
TEST_F(KernelTest, FullGraph) {
@ -108,10 +65,7 @@ TEST_F(KernelTest, FullGraph) {
AddTfOp(testing::kAdd, {2, 5}, {7});
AddTfOp(testing::kMul, {6, 7}, {8});
// Apply Delegate.
ConfigureDelegate([](TfLiteContext* context, TfLiteDelegate* delegate) {
return GenericPrepare(context, delegate, {0, 1, 2, 3, 4});
});
ApplyFlexDelegate();
// Define inputs.
SetShape(0, {2, 2, 1});
@ -140,9 +94,7 @@ TEST_F(KernelTest, BadTensorFlowOp) {
AddTensors(2, {0}, {1}, kTfLiteFloat32, {3});
AddTfOp(testing::kNonExistent, {0}, {1});
ConfigureDelegate([](TfLiteContext* context, TfLiteDelegate* delegate) {
return GenericPrepare(context, delegate, {0});
});
ApplyFlexDelegate(std::unique_ptr<FlexDelegate>(new TestFlexDelegate()));
ASSERT_NE(interpreter_->AllocateTensors(), kTfLiteOk);
ASSERT_THAT(error_reporter().error_messages(),
@ -153,9 +105,7 @@ TEST_F(KernelTest, BadNumberOfOutputs) {
AddTensors(3, {0}, {1, 2}, kTfLiteFloat32, {3});
AddTfOp(testing::kIdentity, {0}, {1, 2});
ConfigureDelegate([](TfLiteContext* context, TfLiteDelegate* delegate) {
return GenericPrepare(context, delegate, {0});
});
ApplyFlexDelegate();
SetShape(0, {2, 2, 1});
SetValues(0, {1.1f, 2.2f, 3.3f, 4.4f});
@ -171,9 +121,7 @@ TEST_F(KernelTest, IncompatibleNodeDef) {
// Cast is a TF op, but we don't add the proper nodedef to it in AddTfOp.
AddTfOp(testing::kIncompatibleNodeDef, {0}, {1});
ConfigureDelegate([](TfLiteContext* context, TfLiteDelegate* delegate) {
return GenericPrepare(context, delegate, {0});
});
ApplyFlexDelegate();
SetShape(0, {2, 2, 1});
SetValues(0, {1.1f, 2.2f, 3.3f, 4.4f});
@ -188,14 +136,14 @@ TEST_F(KernelTest, WrongSetOfNodes) {
AddTfOp(testing::kUnpack, {0}, {1, 2});
AddTfLiteMulOp({1, 2}, {3});
// Specify that testing::kMul (#1) is supported when it actually isn't.
ConfigureDelegate([](TfLiteContext* context, TfLiteDelegate* delegate) {
return GenericPrepare(context, delegate, {0, 1});
});
// Specify that testing::kMul (#1) is supported when it actually isn't so that
// we choose to use the TestFlexDelegate that supports every node regardless
// whether it's actually supported or not.
ApplyFlexDelegate(std::unique_ptr<FlexDelegate>(new TestFlexDelegate()));
ASSERT_NE(interpreter_->AllocateTensors(), kTfLiteOk);
ASSERT_THAT(error_reporter().error_messages(),
ContainsRegex("Invalid NodeDef in Flex op"));
ContainsRegex("Cannot convert empty data into a valid NodeDef"));
}
TEST_F(KernelTest, MixedGraph) {
@ -207,9 +155,7 @@ TEST_F(KernelTest, MixedGraph) {
AddTfOp(testing::kAdd, {2, 5}, {7});
AddTfLiteMulOp({6, 7}, {8});
ConfigureDelegate([](TfLiteContext* context, TfLiteDelegate* delegate) {
return GenericPrepare(context, delegate, {0, 1, 2, 3});
});
ApplyFlexDelegate();
SetShape(0, {2, 2, 1});
SetValues(0, {1.1f, 2.2f, 3.3f, 4.4f});
@ -251,14 +197,7 @@ TEST_F(KernelTest, SplitGraph) {
// The two branches added together:
AddTfOp(testing::kAdd, {9, 16}, {17}); // => 16
ConfigureDelegate([](TfLiteContext* context, TfLiteDelegate* delegate) {
// All ops but #3 are TF ops, handled by the delegate. However, because #4
// depends on the non-TF op, two subgraphs are necessary:
// TF subgraph 1: 0, 1, 2, 6, 7, 8, 9
// TF Lite Op: 3
// TF subgraph 2: 4, 5, 10
return GenericPrepare(context, delegate, {0, 1, 2, 4, 5, 6, 7, 8, 9, 10});
});
ApplyFlexDelegate();
SetShape(0, {2, 2, 2, 1});
SetValues(0, a);
@ -291,9 +230,8 @@ class MultipleSubgraphsTest : public KernelTest {
public:
static constexpr int kInput = 0;
void PrepareInterpreter(PrepareFunction prepare,
const std::vector<float>& input) {
ConfigureDelegate(prepare);
void PrepareInterpreter(const std::vector<float>& input) {
ApplyFlexDelegate();
SetShape(kOnes, {3});
SetValues(kOnes, {1.0f, 1.0f, 1.0f});
@ -336,7 +274,7 @@ TEST_F(MultipleSubgraphsTest, ForwardabilityIsLocal) {
AddTfLiteMulOp({10, 7}, {12});
auto input = {3.0f, 4.0f, 5.0f};
PrepareInterpreter(GetPrepareFunction<__LINE__>(), input);
PrepareInterpreter(input);
ASSERT_TRUE(Invoke());
ASSERT_THAT(GetValues(12), ElementsAreArray(Apply(input, [](float in) {
@ -371,7 +309,7 @@ TEST_F(MultipleSubgraphsTest, DoNotRemoveInputTensors) {
AddTfLiteMulOp({10, 7}, {12});
auto input = {3.0f, 4.0f, 5.0f};
PrepareInterpreter(GetPrepareFunction<__LINE__>(), input);
PrepareInterpreter(input);
ASSERT_TRUE(Invoke());
ASSERT_THAT(GetValues(12), ElementsAreArray(Apply(input, [](float in) {
@ -405,7 +343,7 @@ TEST_F(MultipleSubgraphsTest, DoNotForwardInputTensors) {
AddTfLiteMulOp({10, 7}, {12});
auto input = {3.0f, 4.0f, 5.0f};
PrepareInterpreter(GetPrepareFunction<__LINE__>(), input);
PrepareInterpreter(input);
ASSERT_TRUE(Invoke());
ASSERT_THAT(GetValues(12), ElementsAreArray(Apply(input, [](float in) {
@ -413,7 +351,7 @@ TEST_F(MultipleSubgraphsTest, DoNotForwardInputTensors) {
})));
}
} // namespace
} // namespace testing
} // namespace flex
} // namespace tflite

View File

@ -32,6 +32,7 @@ TfLiteRegistration GetDelegateKernelRegistration(
kernel_registration.profiling_string = nullptr;
kernel_registration.builtin_code = kTfLiteBuiltinDelegate;
kernel_registration.custom_name = delegate->name();
kernel_registration.version = 1;
kernel_registration.free = [](TfLiteContext* context, void* buffer) -> void {
delete reinterpret_cast<SimpleDelegateKernelInterface*>(buffer);
};
@ -77,6 +78,7 @@ TfLiteStatus DelegatePrepare(TfLiteContext* context,
TfLiteDelegate* base_delegate) {
auto* delegate =
reinterpret_cast<SimpleDelegateInterface*>(base_delegate->data_);
TF_LITE_ENSURE_STATUS(delegate->Initialize(context));
delegates::IsNodeSupportedFn node_supported_fn =
[=](TfLiteContext* context, TfLiteNode* node,
TfLiteRegistration* registration,
@ -125,5 +127,4 @@ void TfLiteDelegateFactory::DeleteSimpleDelegate(TfLiteDelegate* delegate) {
delete simple_delegate;
delete delegate;
}
} // namespace tflite

View File

@ -20,8 +20,12 @@ limitations under the License.
// this interface to build/prepare/invoke the delegated subgraph.
// - SimpleDelegateInterface:
// This class wraps TFLiteDelegate and users need to implement the interface and
// then Call GetFinalizedDelegate() to get TfLiteDelegate* that can be passed to
// ModifyGraphWithDelegate.
// then call TfLiteDelegateFactory::CreateSimpleDelegate(...) to get a
// TfLiteDelegate* that can be passed to ModifyGraphWithDelegate and freed via
// TfLiteDelegateFactory::DeleteSimpleDelegate(...),
// or call TfLiteDelegateFactory::Create(...) to get a std::unique_ptr
// TfLiteDelegate that can also be passed to ModifyGraphWithDelegate, in which
// case the TfLite interpreter takes memory ownership of the delegate.
#ifndef TENSORFLOW_LITE_DELEGATES_UTILS_SIMPLE_DELEGATE_H_
#define TENSORFLOW_LITE_DELEGATES_UTILS_SIMPLE_DELEGATE_H_
@ -31,6 +35,9 @@ limitations under the License.
namespace tflite {
using TfLiteDelegateUniquePtr =
std::unique_ptr<TfLiteDelegate, void (*)(TfLiteDelegate*)>;
// Users should inherit from this class and implement the interface below.
// Each instance represents a single part of the graph (subgraph).
class SimpleDelegateKernelInterface {
@ -49,6 +56,7 @@ class SimpleDelegateKernelInterface {
// Actual subgraph inference should happen on this call.
// Returns a status, signalling any errors.
// TODO(b/157882025): change this to Eval to be consistent w/ a TFLite kernel.
virtual TfLiteStatus Invoke(TfLiteContext* context, TfLiteNode* node) = 0;
};
@ -58,6 +66,7 @@ class SimpleDelegateKernelInterface {
//
// Clients should implement the following methods:
// - IsNodeSupportedByDelegate
// - Initialize
// - name
// - CreateDelegateKernelInterface
class SimpleDelegateInterface {
@ -71,8 +80,14 @@ class SimpleDelegateInterface {
const TfLiteNode* node,
TfLiteContext* context) const = 0;
// Initialize the delegate before finding and replacing TfLite nodes with
// delegate kernels, for example, retrieving some TFLite settings from
// 'context'.
virtual TfLiteStatus Initialize(TfLiteContext* context) = 0;
// Returns a name that identifies the delegate.
// This name is used for debugging/logging/profiling.
// TODO(b/157882025): change this to Name()
virtual const char* name() const = 0;
// Returns instance of an object that implements the interface
@ -84,13 +99,8 @@ class SimpleDelegateInterface {
CreateDelegateKernelInterface() = 0;
};
// Factory class that provides two static methods
// CreateSimpleDelegate
// DeleteSimpleDelegate
// which should be used to construct a TfLiteDelegate from a
// SimpleDelegate, and to delete both the TfLiteDelegate and the
// SimpleDelegate given a TfLiteDelegate* created by 'CreateSimpleDelegate'.
// Users should use these methods to create and destroy the delegate.
// Factory class that provides static methods to deal with SimpleDelegate
// creation and deletion.
class TfLiteDelegateFactory {
public:
// Creates TfLiteDelegate from the provided SimpleDelegateInterface.
@ -99,9 +109,17 @@ class TfLiteDelegateFactory {
std::unique_ptr<SimpleDelegateInterface> simple_delegate);
// Deletes 'delegate'; the passed pointer must be the one returned
// from GetFinalizedDelegate.
// from CreateSimpleDelegate.
// This function will destruct the SimpleDelegate object too.
static void DeleteSimpleDelegate(TfLiteDelegate* delegate);
// A convenience function wrapping the above two functions and returning a
// std::unique_ptr type for auto memory management.
inline static TfLiteDelegateUniquePtr Create(
std::unique_ptr<SimpleDelegateInterface> simple_delegate) {
return TfLiteDelegateUniquePtr(
CreateSimpleDelegate(std::move(simple_delegate)), DeleteSimpleDelegate);
}
};
} // namespace tflite

View File

@ -72,7 +72,12 @@ class TestSimpleDelegate : public SimpleDelegateInterface {
return options_.allowed_builtin_code == registration->builtin_code;
}
const char* name() const override { return "TestSimpleDelegate"; }
TfLiteStatus Initialize(TfLiteContext* context) override { return kTfLiteOk; }
const char* name() const override {
static constexpr char kName[] = "TestSimpleDelegate";
return kName;
}
std::unique_ptr<SimpleDelegateKernelInterface> CreateDelegateKernelInterface()
override {
@ -113,27 +118,24 @@ class TestDelegate : public ::testing::Test {
reg);
}
void TearDown() override {
interpreter_.reset();
TfLiteDelegateFactory::DeleteSimpleDelegate(delegate_);
}
void TearDown() override { interpreter_.reset(); }
protected:
std::unique_ptr<Interpreter> interpreter_;
TfLiteDelegate* delegate_ = nullptr;
};
TEST_F(TestDelegate, BasicDelegate) {
TestSimpleDelegateOptions options;
options.allowed_builtin_code = kTfLiteBuiltinAdd;
delegate_ = TfLiteDelegateFactory::CreateSimpleDelegate(
auto delegate = TfLiteDelegateFactory::Create(
std::make_unique<TestSimpleDelegate>(options));
interpreter_->ModifyGraphWithDelegate(delegate_);
interpreter_->ModifyGraphWithDelegate(std::move(delegate));
ASSERT_EQ(interpreter_->execution_plan().size(), 1);
int node = interpreter_->execution_plan()[0];
const auto* node_and_reg = interpreter_->node_and_registration(node);
EXPECT_EQ("TestSimpleDelegate", node_and_reg->second.custom_name);
EXPECT_STREQ("TestSimpleDelegate", node_and_reg->second.custom_name);
EXPECT_EQ(1, node_and_reg->second.version);
const TfLiteDelegateParams* params = static_cast<const TfLiteDelegateParams*>(
node_and_reg->first.builtin_data);
@ -154,9 +156,9 @@ TEST_F(TestDelegate, BasicDelegate) {
TEST_F(TestDelegate, NoNodesToDelegate) {
TestSimpleDelegateOptions options;
options.allowed_builtin_code = kTfLiteBuiltinSub;
delegate_ = TfLiteDelegateFactory::CreateSimpleDelegate(
auto delegate = TfLiteDelegateFactory::Create(
std::make_unique<TestSimpleDelegate>(options));
interpreter_->ModifyGraphWithDelegate(delegate_);
interpreter_->ModifyGraphWithDelegate(std::move(delegate));
ASSERT_EQ(interpreter_->execution_plan().size(), 3);
}
@ -165,19 +167,20 @@ TEST_F(TestDelegate, DelegateFailedPrepare) {
TestSimpleDelegateOptions options;
options.allowed_builtin_code = kTfLiteBuiltinAdd;
options.error_during_prepare = true;
delegate_ = TfLiteDelegateFactory::CreateSimpleDelegate(
auto delegate = TfLiteDelegateFactory::Create(
std::make_unique<TestSimpleDelegate>(options));
ASSERT_EQ(kTfLiteDelegateError,
interpreter_->ModifyGraphWithDelegate(delegate_));
interpreter_->ModifyGraphWithDelegate(std::move(delegate)));
}
TEST_F(TestDelegate, DelegateFailedInvoke) {
TestSimpleDelegateOptions options;
options.allowed_builtin_code = kTfLiteBuiltinAdd;
options.error_during_invoke = true;
delegate_ = TfLiteDelegateFactory::CreateSimpleDelegate(
auto delegate = TfLiteDelegateFactory::Create(
std::make_unique<TestSimpleDelegate>(options));
ASSERT_EQ(kTfLiteOk, interpreter_->ModifyGraphWithDelegate(delegate_));
ASSERT_EQ(kTfLiteOk,
interpreter_->ModifyGraphWithDelegate(std::move(delegate)));
ASSERT_EQ(kTfLiteError, interpreter_->Invoke());
}
@ -185,10 +188,10 @@ TEST_F(TestDelegate, DelegateFailedInit) {
TestSimpleDelegateOptions options;
options.allowed_builtin_code = kTfLiteBuiltinAdd;
options.error_during_init = true;
delegate_ = TfLiteDelegateFactory::CreateSimpleDelegate(
auto delegate = TfLiteDelegateFactory::Create(
std::make_unique<TestSimpleDelegate>(options));
ASSERT_EQ(kTfLiteDelegateError,
interpreter_->ModifyGraphWithDelegate(delegate_));
interpreter_->ModifyGraphWithDelegate(std::move(delegate)));
}
} // namespace
} // namespace tflite

View File

@ -337,10 +337,7 @@ TfLiteDriver::TfLiteDriver(DelegateType delegate_type, bool reference_kernel)
break;
case DelegateType::kFlex:
#if !defined(__APPLE__)
delegate_ = Interpreter::TfLiteDelegatePtr(
FlexDelegate::Create().release(), [](TfLiteDelegate* delegate) {
delete static_cast<tflite::FlexDelegate*>(delegate);
});
delegate_ = FlexDelegate::Create();
#endif
break;
}

View File

@ -744,8 +744,9 @@ typedef struct TfLiteDelegate {
struct TfLiteDelegate* delegate);
// Copy the data from delegate buffer handle into raw memory of the given
// 'tensor'. This cannot be null. The delegate is allowed to allocate the raw
// bytes as long as it follows the rules for kTfLiteDynamic tensors.
// 'tensor'. Note that the delegate is allowed to allocate the raw bytes as
// long as it follows the rules for kTfLiteDynamic tensors, in which case this
// cannot be null.
TfLiteStatus (*CopyFromBufferHandle)(TfLiteContext* context,
struct TfLiteDelegate* delegate,
TfLiteBufferHandle buffer_handle,