Add new method InterpreterBuilder::AddDelegate.

This achieves the same effect as Interpreter::ModifyGraphWithDelegate,
but (like the C API's TfLiteInterpreterOptionsAddDelegate function)
this avoids the need for the user to modify the Interpreter after
it has already been constructed.

PiperOrigin-RevId: 352597994
Change-Id: Iff5c0dbb19c8ef90d4cb77d442b44c9efe328a80
This commit is contained in:
Fergus Henderson 2021-01-19 10:21:58 -08:00 committed by TensorFlower Gardener
parent 5b69bbfb41
commit 8db2d99b1a
6 changed files with 196 additions and 40 deletions

View File

@ -541,6 +541,29 @@ cc_test(
],
)
cc_library(
name = "interpreter_test_util",
testonly = True,
hdrs = ["interpreter_test_util.h"],
deps = [
":builtin_op_data",
":external_cpu_backend_context",
":framework",
":string_util",
":version",
"//tensorflow/lite/c:common",
"//tensorflow/lite/core/api",
"//tensorflow/lite/kernels:builtin_ops",
"//tensorflow/lite/kernels:cpu_backend_context",
"//tensorflow/lite/kernels:kernel_util",
"//tensorflow/lite/kernels/internal:compatibility",
"//tensorflow/lite/schema:schema_fbs",
"//tensorflow/lite/testing:util",
"//third_party/eigen3",
"@com_google_googletest//:gtest",
],
)
# Test main interpreter
cc_test(
name = "interpreter_test",
@ -557,6 +580,7 @@ cc_test(
":builtin_op_data",
":external_cpu_backend_context",
":framework",
":interpreter_test_util",
":string_util",
":util",
":version",
@ -622,6 +646,7 @@ cc_test(
],
deps = [
":framework",
":interpreter_test_util",
"//tensorflow/lite/core/api",
"//tensorflow/lite/kernels:builtin_ops",
"//tensorflow/lite/testing:util",

View File

@ -636,11 +636,17 @@ TfLiteStatus InterpreterBuilder::ApplyDelegates(Interpreter* interpreter,
int num_threads) {
// Apply Flex delegate if applicable.
if (has_flex_op_) {
if (auto flex_delegate = AcquireFlexDelegate()) {
return interpreter->ModifyGraphWithDelegate(std::move(flex_delegate));
if (Interpreter::TfLiteDelegatePtr flex_delegate = AcquireFlexDelegate()) {
TF_LITE_ENSURE_STATUS(interpreter->ModifyGraphWithDelegate(
// Transfers ownership of flex_delegate to the interpreter.
std::move(flex_delegate)));
}
}
for (TfLiteDelegate* delegate : delegates_) {
// Note that we DON'T transfer ownership of the delegate to the interpreter.
// (Doing that would cause problems if operator() was invoked twice.)
TF_LITE_ENSURE_STATUS(interpreter->ModifyGraphWithDelegate(delegate));
}
return kTfLiteOk;
}
@ -764,10 +770,19 @@ TfLiteStatus InterpreterBuilder::operator()(
op_resolver_.GetDelegates(num_threads);
}
if (ApplyDelegates(interpreter->get(), num_threads) != kTfLiteOk)
return cleanup_and_error();
TfLiteStatus status = ApplyDelegates(interpreter->get(), num_threads);
if (status != kTfLiteOk) {
interpreter->reset();
}
return status;
}
return kTfLiteOk;
void InterpreterBuilder::AddDelegate(TfLiteDelegate* delegate) {
if (delegate == nullptr) {
TF_LITE_REPORT_ERROR(error_reporter_, "Null delegate.");
} else {
delegates_.push_back(delegate);
}
}
} // namespace tflite

View File

@ -63,10 +63,20 @@ class InterpreterBuilder {
~InterpreterBuilder();
InterpreterBuilder(const InterpreterBuilder&) = delete;
InterpreterBuilder& operator=(const InterpreterBuilder&) = delete;
TfLiteStatus operator()(std::unique_ptr<Interpreter>* interpreter);
TfLiteStatus operator()(std::unique_ptr<Interpreter>* interpreter,
int num_threads);
/// Any delegates added with AddDelegate will be applied to the Interpreter
/// generated by operator(), in the order that they were added. The delegate
/// parameter passed to AddDelegate should be non-null, otherwise an error
/// will be reported, and the call to AddDelegate will have no other effect).
/// The lifetime of the delegate must be at least as long as the lifetime of
/// any Interpreter generated by this InterpreterBuilder.
/// WARNING: This is an experimental API and subject to change.
void AddDelegate(TfLiteDelegate* delegate);
private:
TfLiteStatus BuildLocalIndexToRegistrationMapping();
TfLiteStatus ParseNodes(
@ -90,6 +100,7 @@ class InterpreterBuilder {
const ::tflite::Model* model_;
const OpResolver& op_resolver_;
ErrorReporter* error_reporter_;
std::vector<TfLiteDelegate*> delegates_;
std::vector<const TfLiteRegistration*> flatbuffer_op_index_to_registration_;
std::vector<TfLiteRegistration> unresolved_custom_ops_;

View File

@ -25,6 +25,7 @@ limitations under the License.
#include "tensorflow/lite/builtin_op_data.h"
#include "tensorflow/lite/core/api/error_reporter.h"
#include "tensorflow/lite/external_cpu_backend_context.h"
#include "tensorflow/lite/interpreter_test_util.h"
#include "tensorflow/lite/kernels/builtin_op_kernels.h"
#include "tensorflow/lite/kernels/cpu_backend_context.h"
#include "tensorflow/lite/kernels/internal/compatibility.h"
@ -37,45 +38,13 @@ limitations under the License.
namespace tflite {
// InterpreterTest is a friend of Interpreter, so it can access context_.
class InterpreterTest : public ::testing::Test {
public:
template <typename Delegate>
static TfLiteStatus ModifyGraphWithDelegate(
Interpreter* interpreter, std::unique_ptr<Delegate> delegate) {
return interpreter->ModifyGraphWithDelegate(std::move(delegate));
}
protected:
TfLiteContext* GetInterpreterContext() { return interpreter_.context_; }
std::vector<Interpreter::TfLiteDelegatePtr>*
mutable_lazy_delegate_providers() {
return &interpreter_.lazy_delegate_providers_;
}
bool HasDelegates() { return interpreter_.HasDelegates(); }
void BuildSignature(const std::string& method_name, const std::string& key,
const std::map<std::string, uint32_t>& inputs,
const std::map<std::string, uint32_t>& outputs) {
Interpreter::SignatureDef signature;
signature.inputs = inputs;
signature.outputs = outputs;
signature.method_name = method_name;
signature.signature_def_key = key;
interpreter_.SetSignatureDef({signature});
}
Interpreter interpreter_;
};
namespace ops {
namespace builtin {
TfLiteRegistration* Register_PADV2();
TfLiteRegistration* Register_NEG();
} // namespace builtin
} // namespace ops
namespace {
using ::testing::IsEmpty;

View File

@ -0,0 +1,68 @@
#ifndef TENSORFLOW_LITE_INTERPRETER_TEST_UTIL_H_
#define TENSORFLOW_LITE_INTERPRETER_TEST_UTIL_H_
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <stdint.h>
#include <map>
#include <memory>
#include <string>
#include <vector>
#include <gtest/gtest.h>
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/kernels/internal/compatibility.h"
#include "tensorflow/lite/string_util.h"
namespace tflite {
// InterpreterTest is a friend of Interpreter, so it can access context_.
class InterpreterTest : public ::testing::Test {
public:
template <typename Delegate>
static TfLiteStatus ModifyGraphWithDelegate(
Interpreter* interpreter, std::unique_ptr<Delegate> delegate) {
return interpreter->ModifyGraphWithDelegate(std::move(delegate));
}
protected:
TfLiteContext* GetInterpreterContext() { return interpreter_.context_; }
std::vector<Interpreter::TfLiteDelegatePtr>*
mutable_lazy_delegate_providers() {
return &interpreter_.lazy_delegate_providers_;
}
bool HasDelegates() { return interpreter_.HasDelegates(); }
void BuildSignature(const std::string& method_name, const std::string& key,
const std::map<std::string, uint32_t>& inputs,
const std::map<std::string, uint32_t>& outputs) {
Interpreter::SignatureDef signature;
signature.inputs = inputs;
signature.outputs = outputs;
signature.method_name = method_name;
signature.signature_def_key = key;
interpreter_.SetSignatureDef({signature});
}
Interpreter interpreter_;
};
} // namespace tflite
#endif // TENSORFLOW_LITE_INTERPRETER_TEST_UTIL_H_

View File

@ -18,11 +18,13 @@ limitations under the License.
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <fstream>
#include <iostream>
#include <gtest/gtest.h>
#include "tensorflow/lite/core/api/error_reporter.h"
#include "tensorflow/lite/interpreter_test_util.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/testing/util.h"
@ -110,7 +112,7 @@ TEST(BasicFlatBufferModel, TestBufferAlignment) {
}
// Make sure a model with nothing in it loads properly.
TEST(BasicFlatBufferModel, TestEmptyModelsAndNullDestination) {
TEST(BasicFlatBufferModel, TestEmptyModels) {
auto model = FlatBufferModel::BuildFromFile(
"tensorflow/lite/testdata/empty_model.bin");
ASSERT_TRUE(model);
@ -119,6 +121,13 @@ TEST(BasicFlatBufferModel, TestEmptyModelsAndNullDestination) {
ASSERT_EQ(InterpreterBuilder(*model, TrivialResolver())(&interpreter),
kTfLiteOk);
ASSERT_NE(interpreter, nullptr);
}
TEST(BasicFlatBufferModel, TestNullDestination) {
auto model = FlatBufferModel::BuildFromFile(
"tensorflow/lite/testdata/empty_model.bin");
ASSERT_TRUE(model);
// Test that building with null destination fails.
ASSERT_NE(InterpreterBuilder(*model, TrivialResolver())(nullptr), kTfLiteOk);
}
@ -482,6 +491,65 @@ TEST(BasicFlatBufferModel, TestHandleMalformedModelInvalidBuffer) {
ASSERT_NE(interpreter->Invoke(), kTfLiteOk);
}
TEST(TestAddDelegateOwnership, AddDelegateDoesNotTakeOwnership) {
class TestDelegate : public TfLiteDelegate {
public:
TestDelegate(bool* destroyed, bool* prepared)
: TfLiteDelegate(TfLiteDelegateCreate()),
destroyed_(destroyed),
prepared_(prepared) {
flags = kTfLiteDelegateFlagsNone;
Prepare = [](TfLiteContext*, TfLiteDelegate* delegate) -> TfLiteStatus {
*(static_cast<TestDelegate*>(delegate)->prepared_) = true;
return kTfLiteOk;
};
}
~TestDelegate() { *destroyed_ = true; }
private:
bool* destroyed_;
bool* prepared_;
};
// Construct a delegate with flags for indicating preparation/destruction.
bool destroyed = false;
bool prepared = false;
{
std::unique_ptr<TestDelegate> delegate(
new TestDelegate(&destroyed, &prepared));
{
// Load a model.
auto model = FlatBufferModel::BuildFromFile(
"tensorflow/lite/testdata/empty_model.bin");
ASSERT_TRUE(model);
// Now try to build it into an interpreter.
std::unique_ptr<Interpreter> interpreter;
InterpreterBuilder builder(*model, TrivialResolver());
builder.AddDelegate(delegate.get()); // Does not transfer ownership.
// Loop to check we can construct multiple interpreters from one builder.
for (int i = 0; i < 3; i++) {
prepared = false;
ASSERT_EQ(builder(&interpreter), kTfLiteOk);
ASSERT_NE(interpreter, nullptr);
// The delegate should be prepared as normal, and should be preserved.
EXPECT_TRUE(prepared);
EXPECT_FALSE(destroyed);
// Interpreter interaction should not impact the delegate's validity.
interpreter->AllocateTensors();
interpreter->Invoke();
EXPECT_FALSE(destroyed);
}
}
EXPECT_NE(delegate, nullptr);
EXPECT_FALSE(destroyed);
}
// Only after the delegate itself goes out of scope should the delegate be
// destroyed.
EXPECT_TRUE(destroyed);
}
// TODO(aselle): Add tests for serialization of builtin op data types.
// These tests will occur with the evaluation tests of individual operators,
// not here.