Add new method InterpreterBuilder::AddDelegate.
This achieves the same effect as Interpreter::ModifyGraphWithDelegate, but (like the C API's TfLiteInterpreterOptionsAddDelegate function) this avoids the need for the user to modify the Interpreter after it has already been constructed. PiperOrigin-RevId: 352597994 Change-Id: Iff5c0dbb19c8ef90d4cb77d442b44c9efe328a80
This commit is contained in:
parent
5b69bbfb41
commit
8db2d99b1a
@ -541,6 +541,29 @@ cc_test(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "interpreter_test_util",
|
||||
testonly = True,
|
||||
hdrs = ["interpreter_test_util.h"],
|
||||
deps = [
|
||||
":builtin_op_data",
|
||||
":external_cpu_backend_context",
|
||||
":framework",
|
||||
":string_util",
|
||||
":version",
|
||||
"//tensorflow/lite/c:common",
|
||||
"//tensorflow/lite/core/api",
|
||||
"//tensorflow/lite/kernels:builtin_ops",
|
||||
"//tensorflow/lite/kernels:cpu_backend_context",
|
||||
"//tensorflow/lite/kernels:kernel_util",
|
||||
"//tensorflow/lite/kernels/internal:compatibility",
|
||||
"//tensorflow/lite/schema:schema_fbs",
|
||||
"//tensorflow/lite/testing:util",
|
||||
"//third_party/eigen3",
|
||||
"@com_google_googletest//:gtest",
|
||||
],
|
||||
)
|
||||
|
||||
# Test main interpreter
|
||||
cc_test(
|
||||
name = "interpreter_test",
|
||||
@ -557,6 +580,7 @@ cc_test(
|
||||
":builtin_op_data",
|
||||
":external_cpu_backend_context",
|
||||
":framework",
|
||||
":interpreter_test_util",
|
||||
":string_util",
|
||||
":util",
|
||||
":version",
|
||||
@ -622,6 +646,7 @@ cc_test(
|
||||
],
|
||||
deps = [
|
||||
":framework",
|
||||
":interpreter_test_util",
|
||||
"//tensorflow/lite/core/api",
|
||||
"//tensorflow/lite/kernels:builtin_ops",
|
||||
"//tensorflow/lite/testing:util",
|
||||
|
@ -636,11 +636,17 @@ TfLiteStatus InterpreterBuilder::ApplyDelegates(Interpreter* interpreter,
|
||||
int num_threads) {
|
||||
// Apply Flex delegate if applicable.
|
||||
if (has_flex_op_) {
|
||||
if (auto flex_delegate = AcquireFlexDelegate()) {
|
||||
return interpreter->ModifyGraphWithDelegate(std::move(flex_delegate));
|
||||
if (Interpreter::TfLiteDelegatePtr flex_delegate = AcquireFlexDelegate()) {
|
||||
TF_LITE_ENSURE_STATUS(interpreter->ModifyGraphWithDelegate(
|
||||
// Transfers ownership of flex_delegate to the interpreter.
|
||||
std::move(flex_delegate)));
|
||||
}
|
||||
}
|
||||
|
||||
for (TfLiteDelegate* delegate : delegates_) {
|
||||
// Note that we DON'T transfer ownership of the delegate to the interpreter.
|
||||
// (Doing that would cause problems if operator() was invoked twice.)
|
||||
TF_LITE_ENSURE_STATUS(interpreter->ModifyGraphWithDelegate(delegate));
|
||||
}
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
@ -764,10 +770,19 @@ TfLiteStatus InterpreterBuilder::operator()(
|
||||
op_resolver_.GetDelegates(num_threads);
|
||||
}
|
||||
|
||||
if (ApplyDelegates(interpreter->get(), num_threads) != kTfLiteOk)
|
||||
return cleanup_and_error();
|
||||
TfLiteStatus status = ApplyDelegates(interpreter->get(), num_threads);
|
||||
if (status != kTfLiteOk) {
|
||||
interpreter->reset();
|
||||
}
|
||||
return status;
|
||||
}
|
||||
|
||||
return kTfLiteOk;
|
||||
void InterpreterBuilder::AddDelegate(TfLiteDelegate* delegate) {
|
||||
if (delegate == nullptr) {
|
||||
TF_LITE_REPORT_ERROR(error_reporter_, "Null delegate.");
|
||||
} else {
|
||||
delegates_.push_back(delegate);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace tflite
|
||||
|
@ -63,10 +63,20 @@ class InterpreterBuilder {
|
||||
~InterpreterBuilder();
|
||||
InterpreterBuilder(const InterpreterBuilder&) = delete;
|
||||
InterpreterBuilder& operator=(const InterpreterBuilder&) = delete;
|
||||
|
||||
TfLiteStatus operator()(std::unique_ptr<Interpreter>* interpreter);
|
||||
TfLiteStatus operator()(std::unique_ptr<Interpreter>* interpreter,
|
||||
int num_threads);
|
||||
|
||||
/// Any delegates added with AddDelegate will be applied to the Interpreter
|
||||
/// generated by operator(), in the order that they were added. The delegate
|
||||
/// parameter passed to AddDelegate should be non-null, otherwise an error
|
||||
/// will be reported, and the call to AddDelegate will have no other effect).
|
||||
/// The lifetime of the delegate must be at least as long as the lifetime of
|
||||
/// any Interpreter generated by this InterpreterBuilder.
|
||||
/// WARNING: This is an experimental API and subject to change.
|
||||
void AddDelegate(TfLiteDelegate* delegate);
|
||||
|
||||
private:
|
||||
TfLiteStatus BuildLocalIndexToRegistrationMapping();
|
||||
TfLiteStatus ParseNodes(
|
||||
@ -90,6 +100,7 @@ class InterpreterBuilder {
|
||||
const ::tflite::Model* model_;
|
||||
const OpResolver& op_resolver_;
|
||||
ErrorReporter* error_reporter_;
|
||||
std::vector<TfLiteDelegate*> delegates_;
|
||||
|
||||
std::vector<const TfLiteRegistration*> flatbuffer_op_index_to_registration_;
|
||||
std::vector<TfLiteRegistration> unresolved_custom_ops_;
|
||||
|
@ -25,6 +25,7 @@ limitations under the License.
|
||||
#include "tensorflow/lite/builtin_op_data.h"
|
||||
#include "tensorflow/lite/core/api/error_reporter.h"
|
||||
#include "tensorflow/lite/external_cpu_backend_context.h"
|
||||
#include "tensorflow/lite/interpreter_test_util.h"
|
||||
#include "tensorflow/lite/kernels/builtin_op_kernels.h"
|
||||
#include "tensorflow/lite/kernels/cpu_backend_context.h"
|
||||
#include "tensorflow/lite/kernels/internal/compatibility.h"
|
||||
@ -37,45 +38,13 @@ limitations under the License.
|
||||
|
||||
namespace tflite {
|
||||
|
||||
// InterpreterTest is a friend of Interpreter, so it can access context_.
|
||||
class InterpreterTest : public ::testing::Test {
|
||||
public:
|
||||
template <typename Delegate>
|
||||
static TfLiteStatus ModifyGraphWithDelegate(
|
||||
Interpreter* interpreter, std::unique_ptr<Delegate> delegate) {
|
||||
return interpreter->ModifyGraphWithDelegate(std::move(delegate));
|
||||
}
|
||||
|
||||
protected:
|
||||
TfLiteContext* GetInterpreterContext() { return interpreter_.context_; }
|
||||
|
||||
std::vector<Interpreter::TfLiteDelegatePtr>*
|
||||
mutable_lazy_delegate_providers() {
|
||||
return &interpreter_.lazy_delegate_providers_;
|
||||
}
|
||||
|
||||
bool HasDelegates() { return interpreter_.HasDelegates(); }
|
||||
|
||||
void BuildSignature(const std::string& method_name, const std::string& key,
|
||||
const std::map<std::string, uint32_t>& inputs,
|
||||
const std::map<std::string, uint32_t>& outputs) {
|
||||
Interpreter::SignatureDef signature;
|
||||
signature.inputs = inputs;
|
||||
signature.outputs = outputs;
|
||||
signature.method_name = method_name;
|
||||
signature.signature_def_key = key;
|
||||
interpreter_.SetSignatureDef({signature});
|
||||
}
|
||||
|
||||
Interpreter interpreter_;
|
||||
};
|
||||
|
||||
namespace ops {
|
||||
namespace builtin {
|
||||
TfLiteRegistration* Register_PADV2();
|
||||
TfLiteRegistration* Register_NEG();
|
||||
} // namespace builtin
|
||||
} // namespace ops
|
||||
|
||||
namespace {
|
||||
|
||||
using ::testing::IsEmpty;
|
||||
|
68
tensorflow/lite/interpreter_test_util.h
Normal file
68
tensorflow/lite/interpreter_test_util.h
Normal file
@ -0,0 +1,68 @@
|
||||
#ifndef TENSORFLOW_LITE_INTERPRETER_TEST_UTIL_H_
|
||||
#define TENSORFLOW_LITE_INTERPRETER_TEST_UTIL_H_
|
||||
|
||||
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <stdint.h>
|
||||
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include "tensorflow/lite/interpreter.h"
|
||||
#include "tensorflow/lite/kernels/internal/compatibility.h"
|
||||
#include "tensorflow/lite/string_util.h"
|
||||
|
||||
namespace tflite {
|
||||
|
||||
// InterpreterTest is a friend of Interpreter, so it can access context_.
|
||||
class InterpreterTest : public ::testing::Test {
|
||||
public:
|
||||
template <typename Delegate>
|
||||
static TfLiteStatus ModifyGraphWithDelegate(
|
||||
Interpreter* interpreter, std::unique_ptr<Delegate> delegate) {
|
||||
return interpreter->ModifyGraphWithDelegate(std::move(delegate));
|
||||
}
|
||||
|
||||
protected:
|
||||
TfLiteContext* GetInterpreterContext() { return interpreter_.context_; }
|
||||
|
||||
std::vector<Interpreter::TfLiteDelegatePtr>*
|
||||
mutable_lazy_delegate_providers() {
|
||||
return &interpreter_.lazy_delegate_providers_;
|
||||
}
|
||||
|
||||
bool HasDelegates() { return interpreter_.HasDelegates(); }
|
||||
|
||||
void BuildSignature(const std::string& method_name, const std::string& key,
|
||||
const std::map<std::string, uint32_t>& inputs,
|
||||
const std::map<std::string, uint32_t>& outputs) {
|
||||
Interpreter::SignatureDef signature;
|
||||
signature.inputs = inputs;
|
||||
signature.outputs = outputs;
|
||||
signature.method_name = method_name;
|
||||
signature.signature_def_key = key;
|
||||
interpreter_.SetSignatureDef({signature});
|
||||
}
|
||||
|
||||
Interpreter interpreter_;
|
||||
};
|
||||
|
||||
} // namespace tflite
|
||||
|
||||
#endif // TENSORFLOW_LITE_INTERPRETER_TEST_UTIL_H_
|
@ -18,11 +18,13 @@ limitations under the License.
|
||||
#include <stdint.h>
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
|
||||
#include <fstream>
|
||||
#include <iostream>
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include "tensorflow/lite/core/api/error_reporter.h"
|
||||
#include "tensorflow/lite/interpreter_test_util.h"
|
||||
#include "tensorflow/lite/kernels/register.h"
|
||||
#include "tensorflow/lite/testing/util.h"
|
||||
|
||||
@ -110,7 +112,7 @@ TEST(BasicFlatBufferModel, TestBufferAlignment) {
|
||||
}
|
||||
|
||||
// Make sure a model with nothing in it loads properly.
|
||||
TEST(BasicFlatBufferModel, TestEmptyModelsAndNullDestination) {
|
||||
TEST(BasicFlatBufferModel, TestEmptyModels) {
|
||||
auto model = FlatBufferModel::BuildFromFile(
|
||||
"tensorflow/lite/testdata/empty_model.bin");
|
||||
ASSERT_TRUE(model);
|
||||
@ -119,6 +121,13 @@ TEST(BasicFlatBufferModel, TestEmptyModelsAndNullDestination) {
|
||||
ASSERT_EQ(InterpreterBuilder(*model, TrivialResolver())(&interpreter),
|
||||
kTfLiteOk);
|
||||
ASSERT_NE(interpreter, nullptr);
|
||||
}
|
||||
|
||||
TEST(BasicFlatBufferModel, TestNullDestination) {
|
||||
auto model = FlatBufferModel::BuildFromFile(
|
||||
"tensorflow/lite/testdata/empty_model.bin");
|
||||
ASSERT_TRUE(model);
|
||||
// Test that building with null destination fails.
|
||||
ASSERT_NE(InterpreterBuilder(*model, TrivialResolver())(nullptr), kTfLiteOk);
|
||||
}
|
||||
|
||||
@ -482,6 +491,65 @@ TEST(BasicFlatBufferModel, TestHandleMalformedModelInvalidBuffer) {
|
||||
ASSERT_NE(interpreter->Invoke(), kTfLiteOk);
|
||||
}
|
||||
|
||||
TEST(TestAddDelegateOwnership, AddDelegateDoesNotTakeOwnership) {
|
||||
class TestDelegate : public TfLiteDelegate {
|
||||
public:
|
||||
TestDelegate(bool* destroyed, bool* prepared)
|
||||
: TfLiteDelegate(TfLiteDelegateCreate()),
|
||||
destroyed_(destroyed),
|
||||
prepared_(prepared) {
|
||||
flags = kTfLiteDelegateFlagsNone;
|
||||
Prepare = [](TfLiteContext*, TfLiteDelegate* delegate) -> TfLiteStatus {
|
||||
*(static_cast<TestDelegate*>(delegate)->prepared_) = true;
|
||||
return kTfLiteOk;
|
||||
};
|
||||
}
|
||||
~TestDelegate() { *destroyed_ = true; }
|
||||
|
||||
private:
|
||||
bool* destroyed_;
|
||||
bool* prepared_;
|
||||
};
|
||||
|
||||
// Construct a delegate with flags for indicating preparation/destruction.
|
||||
bool destroyed = false;
|
||||
bool prepared = false;
|
||||
{
|
||||
std::unique_ptr<TestDelegate> delegate(
|
||||
new TestDelegate(&destroyed, &prepared));
|
||||
{
|
||||
// Load a model.
|
||||
auto model = FlatBufferModel::BuildFromFile(
|
||||
"tensorflow/lite/testdata/empty_model.bin");
|
||||
ASSERT_TRUE(model);
|
||||
// Now try to build it into an interpreter.
|
||||
std::unique_ptr<Interpreter> interpreter;
|
||||
InterpreterBuilder builder(*model, TrivialResolver());
|
||||
builder.AddDelegate(delegate.get()); // Does not transfer ownership.
|
||||
// Loop to check we can construct multiple interpreters from one builder.
|
||||
for (int i = 0; i < 3; i++) {
|
||||
prepared = false;
|
||||
ASSERT_EQ(builder(&interpreter), kTfLiteOk);
|
||||
ASSERT_NE(interpreter, nullptr);
|
||||
|
||||
// The delegate should be prepared as normal, and should be preserved.
|
||||
EXPECT_TRUE(prepared);
|
||||
EXPECT_FALSE(destroyed);
|
||||
|
||||
// Interpreter interaction should not impact the delegate's validity.
|
||||
interpreter->AllocateTensors();
|
||||
interpreter->Invoke();
|
||||
EXPECT_FALSE(destroyed);
|
||||
}
|
||||
}
|
||||
EXPECT_NE(delegate, nullptr);
|
||||
EXPECT_FALSE(destroyed);
|
||||
}
|
||||
// Only after the delegate itself goes out of scope should the delegate be
|
||||
// destroyed.
|
||||
EXPECT_TRUE(destroyed);
|
||||
}
|
||||
|
||||
// TODO(aselle): Add tests for serialization of builtin op data types.
|
||||
// These tests will occur with the evaluation tests of individual operators,
|
||||
// not here.
|
||||
|
Loading…
x
Reference in New Issue
Block a user