Roll forward API for automatic fallback on delegation failure.
PiperOrigin-RevId: 313353207 Change-Id: I0f7824ecc5421a179c10a6de4fc5192e9815abb7
This commit is contained in:
parent
9084090e8d
commit
7738c1818e
@ -376,7 +376,9 @@ cc_test(
|
|||||||
cc_test(
|
cc_test(
|
||||||
name = "interpreter_test",
|
name = "interpreter_test",
|
||||||
size = "small",
|
size = "small",
|
||||||
srcs = ["interpreter_test.cc"],
|
srcs = [
|
||||||
|
"interpreter_test.cc",
|
||||||
|
],
|
||||||
features = ["-dynamic_link_test_srcs"], # see go/dynamic_link_test_srcs
|
features = ["-dynamic_link_test_srcs"], # see go/dynamic_link_test_srcs
|
||||||
tags = [
|
tags = [
|
||||||
"tflite_not_portable_ios", # TODO(b/117786830)
|
"tflite_not_portable_ios", # TODO(b/117786830)
|
||||||
|
@ -533,6 +533,11 @@ void Subgraph::SetCancellationFunction(void* data,
|
|||||||
check_cancelled_func_ = check_cancelled_func;
|
check_cancelled_func_ = check_cancelled_func;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool Subgraph::IsCancelled() {
|
||||||
|
return (check_cancelled_func_ != nullptr) &&
|
||||||
|
(*check_cancelled_func_)(cancellation_data_);
|
||||||
|
}
|
||||||
|
|
||||||
void Subgraph::ReserveNodes(int count) {
|
void Subgraph::ReserveNodes(int count) {
|
||||||
nodes_and_registration_.reserve(count);
|
nodes_and_registration_.reserve(count);
|
||||||
}
|
}
|
||||||
@ -1316,6 +1321,8 @@ TfLiteStatus Subgraph::RemoveAllDelegates() {
|
|||||||
return kTfLiteOk;
|
return kTfLiteOk;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// True when at least one delegate is recorded in delegates_applied_.
bool Subgraph::HasDelegates() {
  const bool no_delegates = delegates_applied_.empty();
  return !no_delegates;
}
|
||||||
|
|
||||||
TfLiteStatus Subgraph::EnsureMemoryAllocations() {
|
TfLiteStatus Subgraph::EnsureMemoryAllocations() {
|
||||||
if (memory_planner_) {
|
if (memory_planner_) {
|
||||||
state_ = kStateUninvokable;
|
state_ = kStateUninvokable;
|
||||||
|
@ -553,6 +553,9 @@ class Subgraph {
|
|||||||
// afterwards.
|
// afterwards.
|
||||||
TfLiteStatus RemoveAllDelegates();
|
TfLiteStatus RemoveAllDelegates();
|
||||||
|
|
||||||
|
// Returns true if the subgraph has delegates applied.
|
||||||
|
bool HasDelegates();
|
||||||
|
|
||||||
// Cleans up data reserved for the given node. Does not remove the {node,
|
// Cleans up data reserved for the given node. Does not remove the {node,
|
||||||
// registration} pair from nodes_and_registration_.
|
// registration} pair from nodes_and_registration_.
|
||||||
void CleanupNode(int node_index);
|
void CleanupNode(int node_index);
|
||||||
@ -578,6 +581,9 @@ class Subgraph {
|
|||||||
// Ensures the memory required is planned and allocated.
|
// Ensures the memory required is planned and allocated.
|
||||||
TfLiteStatus EnsureMemoryAllocations();
|
TfLiteStatus EnsureMemoryAllocations();
|
||||||
|
|
||||||
|
// Returns true if cancellation function returns true.
|
||||||
|
bool IsCancelled();
|
||||||
|
|
||||||
// The state of the Interpreter.
|
// The state of the Interpreter.
|
||||||
enum State {
|
enum State {
|
||||||
// The interpreter isn't ready to be invoked.
|
// The interpreter isn't ready to be invoked.
|
||||||
|
@ -32,6 +32,16 @@ cc_library(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "interpreter_utils",
|
||||||
|
srcs = ["interpreter_utils.cc"],
|
||||||
|
hdrs = ["interpreter_utils.h"],
|
||||||
|
copts = tflite_copts(),
|
||||||
|
deps = [
|
||||||
|
"//tensorflow/lite:framework",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
cc_test(
|
cc_test(
|
||||||
name = "utils_test",
|
name = "utils_test",
|
||||||
srcs = ["utils_test.cc"],
|
srcs = ["utils_test.cc"],
|
||||||
@ -53,6 +63,7 @@ cc_test(
|
|||||||
"tflite_not_portable_ios", # TODO(b/117786830)
|
"tflite_not_portable_ios", # TODO(b/117786830)
|
||||||
],
|
],
|
||||||
deps = [
|
deps = [
|
||||||
|
":interpreter_utils",
|
||||||
"//tensorflow/lite:framework",
|
"//tensorflow/lite:framework",
|
||||||
"//tensorflow/lite:version",
|
"//tensorflow/lite:version",
|
||||||
"//tensorflow/lite/core/api",
|
"//tensorflow/lite/core/api",
|
||||||
|
@ -19,6 +19,7 @@ limitations under the License.
|
|||||||
|
|
||||||
#include <gmock/gmock.h>
|
#include <gmock/gmock.h>
|
||||||
#include <gtest/gtest.h>
|
#include <gtest/gtest.h>
|
||||||
|
#include "tensorflow/lite/delegates/interpreter_utils.h"
|
||||||
#include "tensorflow/lite/interpreter.h"
|
#include "tensorflow/lite/interpreter.h"
|
||||||
#include "tensorflow/lite/kernels/internal/compatibility.h"
|
#include "tensorflow/lite/kernels/internal/compatibility.h"
|
||||||
#include "tensorflow/lite/kernels/kernel_util.h"
|
#include "tensorflow/lite/kernels/kernel_util.h"
|
||||||
@ -261,8 +262,10 @@ class TestDelegate : public ::testing::Test {
|
|||||||
for (int i = 0; i < num; i++) {
|
for (int i = 0; i < num; i++) {
|
||||||
out->data.f[i] = a0->data.f[i] + a1->data.f[i];
|
out->data.f[i] = a0->data.f[i] + a1->data.f[i];
|
||||||
}
|
}
|
||||||
// Make the data stale so that CopyFromBufferHandle can be invoked
|
if (out->buffer_handle != kTfLiteNullBufferHandle) {
|
||||||
out->data_is_stale = true;
|
// Make the data stale so that CopyFromBufferHandle can be invoked
|
||||||
|
out->data_is_stale = true;
|
||||||
|
}
|
||||||
return kTfLiteOk;
|
return kTfLiteOk;
|
||||||
};
|
};
|
||||||
if (fail_delegate_node_invoke_) {
|
if (fail_delegate_node_invoke_) {
|
||||||
@ -397,6 +400,34 @@ TEST_F(TestDelegate, DelegateNodeInvokeFailure) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// A delegate that takes over all three nodes but fails at invoke time must be
// undone by InvokeWithCPUFallback, after which the original plan produces the
// expected output.
TEST_F(TestDelegate, DelegateNodeInvokeFailureFallback) {
  constexpr int kNumElements = 3;
  constexpr int kOutputTensorIndex = 3;

  delegate_.reset(new SimpleDelegate(
      {0, 1, 2}, kTfLiteDelegateFlagsNone, /*fail_node_prepare=*/false,
      /*min_ops_per_subset=*/0, /*fail_node_invoke=*/true));
  const TfLiteStatus delegation_status =
      interpreter_->ModifyGraphWithDelegate(delegate_->get_tf_lite_delegate());
  ASSERT_EQ(delegation_status, kTfLiteOk);
  // Delegation modified execution plan.
  ASSERT_EQ(interpreter_->execution_plan().size(), 1);

  const std::vector<float> input = {1.0f, 2.0f, 3.0f};
  const std::vector<float> expected_output = {2.0f, 4.0f, 6.0f};

  // Both graph inputs receive the same values.
  for (const int input_tensor_index : {0, 1}) {
    memcpy(interpreter_->typed_tensor<float>(input_tensor_index), input.data(),
           kNumElements * sizeof(float));
  }

  const TfLiteStatus fallback_status =
      delegates::InterpreterUtils::InvokeWithCPUFallback(interpreter_.get());
  EXPECT_EQ(fallback_status, kTfLiteDelegateError);
  // Delegation removed, returning to original execution plan.
  ASSERT_EQ(interpreter_->execution_plan().size(), 3);

  // Check outputs.
  TfLiteTensor* output_tensor = interpreter_->tensor(kOutputTensorIndex);
  for (int i = 0; i < kNumElements; ++i) {
    EXPECT_EQ(output_tensor->data.f[i], expected_output[i]) << i;
  }
}
|
||||||
|
|
||||||
TEST_F(TestDelegate, SecondDelegationPrepareFailure) {
|
TEST_F(TestDelegate, SecondDelegationPrepareFailure) {
|
||||||
// First delegate only supports nodes 1, 2. Gets applied successfully.
|
// First delegate only supports nodes 1, 2. Gets applied successfully.
|
||||||
// This delegate should support dynamic tensors, otherwise the second won't be
|
// This delegate should support dynamic tensors, otherwise the second won't be
|
||||||
@ -713,6 +744,44 @@ TEST_F(TestDelegate, TestResizeInputWithMultipleDelegates) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// With two delegates applied, where the second one fails at invoke time,
// InvokeWithCPUFallback must undo *all* delegates and still produce the
// expected values via the original three-node plan.
TEST_F(TestDelegate, TestFallbackWithMultipleDelegates) {
  // First delegate only supports node 0.
  // This delegate should support dynamic tensors, otherwise the second won't be
  // applied.
  delegate_ = std::unique_ptr<SimpleDelegate>(
      new SimpleDelegate({0}, kTfLiteDelegateFlagsAllowDynamicTensors));
  // Second delegate supports nodes 1 & 2, and makes the graph immutable.
  delegate2_ = std::unique_ptr<SimpleDelegate>(new SimpleDelegate(
      {1, 2}, kTfLiteDelegateFlagsNone, false /**fail_node_prepare**/,
      0 /**min_ops_per_subset**/, true /**fail_node_invoke**/));
  // Pre-delegation execution plan should have three nodes.
  ASSERT_EQ(interpreter_->execution_plan().size(), 3);
  ASSERT_EQ(
      interpreter_->ModifyGraphWithDelegate(delegate_->get_tf_lite_delegate()),
      kTfLiteOk);
  ASSERT_EQ(
      interpreter_->ModifyGraphWithDelegate(delegate2_->get_tf_lite_delegate()),
      kTfLiteOk);
  // Should be two delegate nodes.
  ASSERT_EQ(interpreter_->execution_plan().size(), 2);

  std::vector<float> input = {1.0f, 2.0f, 3.0f};
  std::vector<float> expected_output = {2.0f, 4.0f, 6.0f};
  constexpr int kOutputTensorIndex = 2;

  memcpy(interpreter_->typed_tensor<float>(0), input.data(), 3 * sizeof(float));
  memcpy(interpreter_->typed_tensor<float>(1), input.data(), 3 * sizeof(float));
  EXPECT_EQ(
      delegates::InterpreterUtils::InvokeWithCPUFallback(interpreter_.get()),
      kTfLiteDelegateError);
  // All delegates should be undone.
  EXPECT_EQ(interpreter_->execution_plan().size(), 3);
  // Look the tensor up only after the fallback: the fallback removes the
  // delegates and re-invokes, and callers of InvokeWithCPUFallback are expected
  // not to hold tensor pointers across that call.
  TfLiteTensor* tensor = interpreter_->tensor(kOutputTensorIndex);
  for (int i = 0; i < 3; ++i) {
    EXPECT_EQ(tensor->data.f[i], expected_output[i]) << i;
  }
}
|
||||||
|
|
||||||
TEST_F(TestDelegate, ReleaseNonPersistentMemoryWithDelegates) {
|
TEST_F(TestDelegate, ReleaseNonPersistentMemoryWithDelegates) {
|
||||||
// First delegate only supports node 0.
|
// First delegate only supports node 0.
|
||||||
// This delegate should support dynamic tensors, otherwise the second won't be
|
// This delegate should support dynamic tensors, otherwise the second won't be
|
||||||
|
65
tensorflow/lite/delegates/interpreter_utils.cc
Normal file
65
tensorflow/lite/delegates/interpreter_utils.cc
Normal file
@ -0,0 +1,65 @@
|
|||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/lite/delegates/interpreter_utils.h"
|
||||||
|
|
||||||
|
namespace tflite {
|
||||||
|
namespace delegates {
|
||||||
|
// Runs the interpreter once; if that run fails while delegates are applied
// (and the run was not cancelled), removes every delegate, restores the input
// data and runs again.
//
// Returns the first Invoke() status when no retry is attempted,
// kTfLiteDelegateError when the retry succeeds, and otherwise whatever
// TF_LITE_ENSURE_STATUS propagates from the failing step.
TfLiteStatus InterpreterUtils::InvokeWithCPUFallback(Interpreter* interpreter) {
  const TfLiteStatus first_status = interpreter->Invoke();
  if (first_status == kTfLiteOk) {
    return first_status;
  }
  if (interpreter->IsCancelled()) {
    return first_status;
  }
  if (!interpreter->HasDelegates()) {
    return first_status;
  }

  // Retry without delegation.
  // TODO(b/138706191): retry only if error is due to delegation.
  TF_LITE_REPORT_ERROR(
      interpreter->error_reporter(),
      "Invoke() failed in the presence of delegation. Retrying without.");

  // Snapshot the input tensors into one contiguous buffer.
  // Input data is safe since Subgraph::PrepareOpsAndTensors() passes
  // preserve_inputs=true to ArenaPlanner.
  size_t total_input_bytes = 0;
  for (const int input_index : interpreter->inputs()) {
    TF_LITE_ENSURE_STATUS(interpreter->EnsureTensorDataIsReadable(input_index));
    total_input_bytes += interpreter->tensor(input_index)->bytes;
  }

  std::vector<char> saved_inputs;
  saved_inputs.reserve(total_input_bytes);
  for (const int input_index : interpreter->inputs()) {
    const TfLiteTensor* source = interpreter->tensor(input_index);
    const char* first = source->data.raw;
    const char* last = first + source->bytes;
    saved_inputs.insert(saved_inputs.end(), first, last);
  }

  TF_LITE_ENSURE_STATUS(interpreter->RemoveAllDelegates());

  // Write the snapshot back into the input tensors, in the same order it was
  // taken.
  auto cursor = saved_inputs.begin();
  for (const int input_index : interpreter->inputs()) {
    TfLiteTensor* destination = interpreter->tensor(input_index);
    const auto chunk_end = cursor + destination->bytes;
    std::copy(cursor, chunk_end, destination->data.raw);
    cursor = chunk_end;
  }

  // Invoke again, this time without any delegate.
  TF_LITE_ENSURE_STATUS(interpreter->Invoke());
  return kTfLiteDelegateError;
}
|
||||||
|
|
||||||
|
} // namespace delegates
|
||||||
|
} // namespace tflite
|
52
tensorflow/lite/delegates/interpreter_utils.h
Normal file
52
tensorflow/lite/delegates/interpreter_utils.h
Normal file
@ -0,0 +1,52 @@
|
|||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#ifndef TENSORFLOW_LITE_DELEGATES_INTERPRETER_UTILS_H_
|
||||||
|
#define TENSORFLOW_LITE_DELEGATES_INTERPRETER_UTILS_H_
|
||||||
|
|
||||||
|
#include "tensorflow/lite/interpreter.h"
|
||||||
|
|
||||||
|
// Utility functions and classes for using delegates.
|
||||||
|
|
||||||
|
namespace tflite {
|
||||||
|
namespace delegates {
|
||||||
|
#if !TFLITE_EXPERIMENTAL_RUNTIME_EAGER
|
||||||
|
/// Static helpers for invoking an Interpreter that has delegates applied.
class InterpreterUtils {
|
||||||
|
public:
|
||||||
|
/// Invokes an interpreter with automatic fallback from delegation to CPU.
|
||||||
|
///
|
||||||
|
/// If using the delegate fails, the delegate is automatically undone and an
|
||||||
|
/// attempt is made to return the interpreter to an invokable state.
|
||||||
|
///
|
||||||
|
/// Allowing the fallback is suitable only if both of the following hold:
|
||||||
|
/// - The caller is known not to cache pointers to tensor data across Invoke()
|
||||||
|
/// calls.
|
||||||
|
/// - The model is not stateful (no variables, no LSTMs) or the state isn't
|
||||||
|
/// needed between batches.
|
||||||
|
///
|
||||||
|
/// Returns one of the following three status codes:
|
||||||
|
/// 1. kTfLiteOk: Success. Output is valid.
|
||||||
|
/// 2. kTfLiteDelegateError: Delegate error but fallback succeeded. Output is
|
||||||
|
/// valid.
|
||||||
|
/// NOTE: This undoes all delegates previously applied to the Interpreter.
|
||||||
|
/// 3. kTfLiteError: Unexpected/runtime failure. Output is invalid.
|
||||||
|
/// WARNING: This is an experimental API and subject to change.
/// `interpreter` must not be null; it is dereferenced unconditionally.
|
||||||
|
static TfLiteStatus InvokeWithCPUFallback(Interpreter* interpreter);
|
||||||
|
};
|
||||||
|
#endif // !TFLITE_EXPERIMENTAL_RUNTIME_EAGER
|
||||||
|
} // namespace delegates
|
||||||
|
} // namespace tflite
|
||||||
|
|
||||||
|
#endif // TENSORFLOW_LITE_DELEGATES_INTERPRETER_UTILS_H_
|
@ -16,6 +16,8 @@ limitations under the License.
|
|||||||
#ifndef TENSORFLOW_LITE_DELEGATES_UTILS_H_
|
#ifndef TENSORFLOW_LITE_DELEGATES_UTILS_H_
|
||||||
#define TENSORFLOW_LITE_DELEGATES_UTILS_H_
|
#define TENSORFLOW_LITE_DELEGATES_UTILS_H_
|
||||||
|
|
||||||
|
// Utility functions and classes for implementing delegates.
|
||||||
|
|
||||||
#include <functional>
|
#include <functional>
|
||||||
#include <limits>
|
#include <limits>
|
||||||
#include <set>
|
#include <set>
|
||||||
|
@ -310,6 +310,8 @@ void Interpreter::SetCancellationFunction(void* data,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Forwards the cancellation query to the primary subgraph.
bool Interpreter::IsCancelled() {
  auto& primary = primary_subgraph();
  return primary.IsCancelled();
}
|
||||||
|
|
||||||
TfLiteStatus Interpreter::ModifyGraphWithDelegate(TfLiteDelegate* delegate) {
|
TfLiteStatus Interpreter::ModifyGraphWithDelegate(TfLiteDelegate* delegate) {
|
||||||
TfLiteStatus status = kTfLiteOk;
|
TfLiteStatus status = kTfLiteOk;
|
||||||
for (auto& subgraph : subgraphs_) {
|
for (auto& subgraph : subgraphs_) {
|
||||||
@ -340,6 +342,8 @@ TfLiteStatus Interpreter::RemoveAllDelegates() {
|
|||||||
return kTfLiteOk;
|
return kTfLiteOk;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Forwards the query to the primary subgraph; only that subgraph is consulted.
bool Interpreter::HasDelegates() {
  auto& primary = primary_subgraph();
  return primary.HasDelegates();
}
|
||||||
|
|
||||||
TfLiteStatus Interpreter::SetBufferHandle(int tensor_index,
|
TfLiteStatus Interpreter::SetBufferHandle(int tensor_index,
|
||||||
TfLiteBufferHandle buffer_handle,
|
TfLiteBufferHandle buffer_handle,
|
||||||
TfLiteDelegate* delegate) {
|
TfLiteDelegate* delegate) {
|
||||||
|
@ -42,6 +42,9 @@ namespace tflite {
|
|||||||
|
|
||||||
class InterpreterTest;
|
class InterpreterTest;
|
||||||
class TestDelegate;
|
class TestDelegate;
|
||||||
|
namespace delegates {
|
||||||
|
class InterpreterUtils; // Class for friend declarations.
|
||||||
|
} // namespace delegates
|
||||||
|
|
||||||
namespace impl {
|
namespace impl {
|
||||||
|
|
||||||
@ -529,6 +532,7 @@ class Interpreter {
|
|||||||
friend class InterpreterBuilder;
|
friend class InterpreterBuilder;
|
||||||
friend class tflite::InterpreterTest;
|
friend class tflite::InterpreterTest;
|
||||||
friend class tflite::TestDelegate;
|
friend class tflite::TestDelegate;
|
||||||
|
friend class tflite::delegates::InterpreterUtils;
|
||||||
|
|
||||||
/// Set the value of an external context.
|
/// Set the value of an external context.
|
||||||
static void SetExternalContext(struct TfLiteContext* context,
|
static void SetExternalContext(struct TfLiteContext* context,
|
||||||
@ -542,6 +546,15 @@ class Interpreter {
|
|||||||
// afterwards.
|
// afterwards.
|
||||||
TfLiteStatus RemoveAllDelegates();
|
TfLiteStatus RemoveAllDelegates();
|
||||||
|
|
||||||
|
// Returns true if delegates have been applied.
|
||||||
|
bool HasDelegates();
|
||||||
|
|
||||||
|
// Returns true if cancellation function returns true.
|
||||||
|
bool IsCancelled();
|
||||||
|
|
||||||
|
// Get the error reporter associated with this interpreter.
|
||||||
|
ErrorReporter* error_reporter() {
  // Plain accessor: hands back the stored reporter pointer unchanged.
  return error_reporter_;
}
|
||||||
|
|
||||||
// A pure C data structure used to communicate with the pure C plugin
|
// A pure C data structure used to communicate with the pure C plugin
|
||||||
// interface. To avoid copying tensor metadata, this is also the definitive
|
// interface. To avoid copying tensor metadata, this is also the definitive
|
||||||
// structure to store tensors.
|
// structure to store tensors.
|
||||||
|
Loading…
Reference in New Issue
Block a user