Roll forward API for automatic fallback on delegation failure.
PiperOrigin-RevId: 313353207 Change-Id: I0f7824ecc5421a179c10a6de4fc5192e9815abb7
This commit is contained in:
parent
9084090e8d
commit
7738c1818e
@ -376,7 +376,9 @@ cc_test(
|
||||
cc_test(
|
||||
name = "interpreter_test",
|
||||
size = "small",
|
||||
srcs = ["interpreter_test.cc"],
|
||||
srcs = [
|
||||
"interpreter_test.cc",
|
||||
],
|
||||
features = ["-dynamic_link_test_srcs"], # see go/dynamic_link_test_srcs
|
||||
tags = [
|
||||
"tflite_not_portable_ios", # TODO(b/117786830)
|
||||
|
@ -533,6 +533,11 @@ void Subgraph::SetCancellationFunction(void* data,
|
||||
check_cancelled_func_ = check_cancelled_func;
|
||||
}
|
||||
|
||||
bool Subgraph::IsCancelled() {
|
||||
return (check_cancelled_func_ != nullptr) &&
|
||||
(*check_cancelled_func_)(cancellation_data_);
|
||||
}
|
||||
|
||||
// Pre-reserves capacity in nodes_and_registration_ for `count` entries.
void Subgraph::ReserveNodes(int count) {
|
||||
nodes_and_registration_.reserve(count);
|
||||
}
|
||||
@ -1316,6 +1321,8 @@ TfLiteStatus Subgraph::RemoveAllDelegates() {
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
// Reports whether at least one delegate is recorded in delegates_applied_.
bool Subgraph::HasDelegates() {
  const bool none_applied = delegates_applied_.empty();
  return !none_applied;
}
|
||||
|
||||
TfLiteStatus Subgraph::EnsureMemoryAllocations() {
|
||||
if (memory_planner_) {
|
||||
state_ = kStateUninvokable;
|
||||
|
@ -553,6 +553,9 @@ class Subgraph {
|
||||
// afterwards.
|
||||
TfLiteStatus RemoveAllDelegates();
|
||||
|
||||
// Returns true if the subgraph has delegates applied.
|
||||
bool HasDelegates();
|
||||
|
||||
// Cleans up data reserved for the given node. Does not remove the {node,
|
||||
// registration} pair from nodes_and_registration_.
|
||||
void CleanupNode(int node_index);
|
||||
@ -578,6 +581,9 @@ class Subgraph {
|
||||
// Ensures the memory required is planned and allocated.
|
||||
TfLiteStatus EnsureMemoryAllocations();
|
||||
|
||||
// Returns true if cancellation function returns true.
|
||||
bool IsCancelled();
|
||||
|
||||
// The state of the Interpreter.
|
||||
enum State {
|
||||
// The interpreter isn't ready to be invoked.
|
||||
|
@ -32,6 +32,16 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
# Library target for interpreter_utils.{h,cc}: utilities for using delegates
# with an Interpreter (e.g. invoking with automatic CPU fallback).
cc_library(
|
||||
name = "interpreter_utils",
|
||||
srcs = ["interpreter_utils.cc"],
|
||||
hdrs = ["interpreter_utils.h"],
|
||||
copts = tflite_copts(),
|
||||
deps = [
|
||||
"//tensorflow/lite:framework",
|
||||
],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "utils_test",
|
||||
srcs = ["utils_test.cc"],
|
||||
@ -53,6 +63,7 @@ cc_test(
|
||||
"tflite_not_portable_ios", # TODO(b/117786830)
|
||||
],
|
||||
deps = [
|
||||
":interpreter_utils",
|
||||
"//tensorflow/lite:framework",
|
||||
"//tensorflow/lite:version",
|
||||
"//tensorflow/lite/core/api",
|
||||
|
@ -19,6 +19,7 @@ limitations under the License.
|
||||
|
||||
#include <gmock/gmock.h>
|
||||
#include <gtest/gtest.h>
|
||||
#include "tensorflow/lite/delegates/interpreter_utils.h"
|
||||
#include "tensorflow/lite/interpreter.h"
|
||||
#include "tensorflow/lite/kernels/internal/compatibility.h"
|
||||
#include "tensorflow/lite/kernels/kernel_util.h"
|
||||
@ -261,8 +262,10 @@ class TestDelegate : public ::testing::Test {
|
||||
for (int i = 0; i < num; i++) {
|
||||
out->data.f[i] = a0->data.f[i] + a1->data.f[i];
|
||||
}
|
||||
// Make the data stale so that CopyFromBufferHandle can be invoked
|
||||
out->data_is_stale = true;
|
||||
if (out->buffer_handle != kTfLiteNullBufferHandle) {
|
||||
// Make the data stale so that CopyFromBufferHandle can be invoked
|
||||
out->data_is_stale = true;
|
||||
}
|
||||
return kTfLiteOk;
|
||||
};
|
||||
if (fail_delegate_node_invoke_) {
|
||||
@ -397,6 +400,34 @@ TEST_F(TestDelegate, DelegateNodeInvokeFailure) {
|
||||
}
|
||||
}
|
||||
|
||||
// Checks that when the single delegated node fails at invoke time,
// InvokeWithCPUFallback() reports kTfLiteDelegateError, restores the original
// three-node execution plan, and still yields the expected output.
TEST_F(TestDelegate, DelegateNodeInvokeFailureFallback) {
  // The delegate claims nodes 0, 1 and 2; it prepares fine but fails on invoke.
  delegate_ = std::unique_ptr<SimpleDelegate>(new SimpleDelegate(
      {0, 1, 2}, kTfLiteDelegateFlagsNone, /*fail_node_prepare=*/false,
      /*min_ops_per_subset=*/0, /*fail_node_invoke=*/true));
  auto* tflite_delegate = delegate_->get_tf_lite_delegate();
  ASSERT_EQ(interpreter_->ModifyGraphWithDelegate(tflite_delegate), kTfLiteOk);
  // Delegation collapsed the execution plan into one node.
  ASSERT_EQ(interpreter_->execution_plan().size(), 1);

  constexpr int kOutputTensorIndex = 3;
  const std::vector<float> input_values = {1.0f, 2.0f, 3.0f};
  const std::vector<float> expected_values = {2.0f, 4.0f, 6.0f};

  // Tensors 0 and 1 are both filled with the same input data.
  for (const int input_index : {0, 1}) {
    memcpy(interpreter_->typed_tensor<float>(input_index), input_values.data(),
           3 * sizeof(float));
  }

  const TfLiteStatus fallback_status =
      delegates::InterpreterUtils::InvokeWithCPUFallback(interpreter_.get());
  EXPECT_EQ(fallback_status, kTfLiteDelegateError);
  // The delegate is gone, so the original execution plan is back.
  ASSERT_EQ(interpreter_->execution_plan().size(), 3);

  // Compare the output tensor element by element.
  TfLiteTensor* output = interpreter_->tensor(kOutputTensorIndex);
  for (int idx = 0; idx < 3; ++idx) {
    EXPECT_EQ(output->data.f[idx], expected_values[idx]) << idx;
  }
}
|
||||
|
||||
TEST_F(TestDelegate, SecondDelegationPrepareFailure) {
|
||||
// First delegate only supports nodes 1, 2. Gets applied successfully.
|
||||
// This delegate should support dynamic tensors, otherwise the second won't be
|
||||
@ -713,6 +744,44 @@ TEST_F(TestDelegate, TestResizeInputWithMultipleDelegates) {
|
||||
}
|
||||
}
|
||||
|
||||
// Checks that a failing delegated invoke undoes every applied delegate (not
// just the failing one) and that the CPU retry produces the expected values.
TEST_F(TestDelegate, TestFallbackWithMultipleDelegates) {
  // First delegate only supports node 0.
  // This delegate should support dynamic tensors, otherwise the second won't be
  // applied.
  delegate_ = std::unique_ptr<SimpleDelegate>(
      new SimpleDelegate({0}, kTfLiteDelegateFlagsAllowDynamicTensors));
  // Second delegate supports nodes 1 & 2, makes the graph immutable, and is
  // configured to fail when its node is invoked.
  delegate2_ = std::unique_ptr<SimpleDelegate>(new SimpleDelegate(
      {1, 2}, kTfLiteDelegateFlagsNone, false /**fail_node_prepare**/,
      0 /**min_ops_per_subset**/, true /**fail_node_invoke**/));
  // Pre-delegation execution plan should have three nodes.
  ASSERT_EQ(interpreter_->execution_plan().size(), 3);
  ASSERT_EQ(
      interpreter_->ModifyGraphWithDelegate(delegate_->get_tf_lite_delegate()),
      kTfLiteOk);
  ASSERT_EQ(
      interpreter_->ModifyGraphWithDelegate(delegate2_->get_tf_lite_delegate()),
      kTfLiteOk);
  // Should be two delegate nodes.
  ASSERT_EQ(interpreter_->execution_plan().size(), 2);

  std::vector<float> input = {1.0f, 2.0f, 3.0f};
  std::vector<float> expected_output = {2.0f, 4.0f, 6.0f};
  constexpr int kOutputTensorIndex = 2;

  // Size the copies from the vector itself rather than a repeated literal.
  const size_t input_bytes = input.size() * sizeof(float);
  memcpy(interpreter_->typed_tensor<float>(0), input.data(), input_bytes);
  memcpy(interpreter_->typed_tensor<float>(1), input.data(), input_bytes);
  EXPECT_EQ(
      delegates::InterpreterUtils::InvokeWithCPUFallback(interpreter_.get()),
      kTfLiteDelegateError);
  // All delegates should be undone.
  EXPECT_EQ(interpreter_->execution_plan().size(), 3);
  // Check outputs. The tensor is looked up only after the fallback, matching
  // DelegateNodeInvokeFailureFallback: the fallback removes all delegates, so
  // nothing obtained from the interpreter beforehand is relied upon here.
  TfLiteTensor* tensor = interpreter_->tensor(kOutputTensorIndex);
  for (int i = 0; i < 3; ++i) {
    EXPECT_EQ(tensor->data.f[i], expected_output[i]) << i;
  }
}
|
||||
|
||||
TEST_F(TestDelegate, ReleaseNonPersistentMemoryWithDelegates) {
|
||||
// First delegate only supports node 0.
|
||||
// This delegate should support dynamic tensors, otherwise the second won't be
|
||||
|
65
tensorflow/lite/delegates/interpreter_utils.cc
Normal file
65
tensorflow/lite/delegates/interpreter_utils.cc
Normal file
@ -0,0 +1,65 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/lite/delegates/interpreter_utils.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace delegates {
|
||||
// Invokes `interpreter`. If that Invoke() fails while delegates are applied
// and the run was not cancelled, removes all delegates and invokes once more.
//
// Returns:
//   - kTfLiteError if `interpreter` is null.
//   - The status of the first Invoke() if it succeeded, if the run was
//     cancelled, or if no delegates were applied.
//   - kTfLiteDelegateError if the first Invoke() failed and the retry without
//     delegates went through.
//   - Failures while saving inputs, removing delegates, or re-invoking are
//     handed to TF_LITE_ENSURE_STATUS.
TfLiteStatus InterpreterUtils::InvokeWithCPUFallback(Interpreter* interpreter) {
  if (interpreter == nullptr) {
    // Without an interpreter there is no error reporter to log to either.
    return kTfLiteError;
  }

  TfLiteStatus status = interpreter->Invoke();
  if (status == kTfLiteOk || interpreter->IsCancelled() ||
      !interpreter->HasDelegates()) {
    return status;
  }
  // Retry without delegation.
  // TODO(b/138706191): retry only if error is due to delegation.
  TF_LITE_REPORT_ERROR(
      interpreter->error_reporter(),
      "Invoke() failed in the presence of delegation. Retrying without.");

  // Copy input data to a buffer.
  // Input data is safe since Subgraph::PrepareOpsAndTensors() passes
  // preserve_inputs=true to ArenaPlanner.
  std::vector<char> buf;
  size_t input_size = 0;

  // First pass: make every input readable and total the bytes so the buffer
  // can be reserved in one go.
  for (auto i : interpreter->inputs()) {
    TF_LITE_ENSURE_STATUS(interpreter->EnsureTensorDataIsReadable(i));
    TfLiteTensor* t = interpreter->tensor(i);
    input_size += t->bytes;
  }
  buf.reserve(input_size);
  // Second pass: append each input's raw bytes back to back, in inputs() order.
  for (auto i : interpreter->inputs()) {
    TfLiteTensor* t = interpreter->tensor(i);
    buf.insert(buf.end(), t->data.raw, t->data.raw + t->bytes);
  }

  TF_LITE_ENSURE_STATUS(interpreter->RemoveAllDelegates());

  // Copy inputs from buffer, walking it in the same order it was filled.
  auto bufp = buf.begin();
  for (auto i : interpreter->inputs()) {
    TfLiteTensor* t = interpreter->tensor(i);
    std::copy(bufp, bufp + t->bytes, t->data.raw);
    bufp += t->bytes;
  }

  // Invoke again.
  TF_LITE_ENSURE_STATUS(interpreter->Invoke());
  // The retry went through, but the caller is still told that delegation
  // failed.
  return kTfLiteDelegateError;
}
|
||||
|
||||
} // namespace delegates
|
||||
} // namespace tflite
|
52
tensorflow/lite/delegates/interpreter_utils.h
Normal file
52
tensorflow/lite/delegates/interpreter_utils.h
Normal file
@ -0,0 +1,52 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_LITE_DELEGATES_INTERPRETER_UTILS_H_
|
||||
#define TENSORFLOW_LITE_DELEGATES_INTERPRETER_UTILS_H_
|
||||
|
||||
#include "tensorflow/lite/interpreter.h"
|
||||
|
||||
// Utility functions and classes for using delegates.
|
||||
|
||||
namespace tflite {
|
||||
namespace delegates {
|
||||
#if !TFLITE_EXPERIMENTAL_RUNTIME_EAGER
|
||||
// Static helpers for running an Interpreter that may have delegates applied.
class InterpreterUtils {
|
||||
public:
|
||||
/// Invokes an interpreter with automatic fallback from delegation to CPU.
|
||||
///
|
||||
/// If using the delegate fails, the delegate is automatically undone and an
|
||||
/// attempt is made to return the interpreter to an invokable state.
|
||||
///
|
||||
/// Allowing the fallback is suitable only if both of the following hold:
|
||||
/// - The caller is known not to cache pointers to tensor data across Invoke()
|
||||
/// calls.
|
||||
/// - The model is not stateful (no variables, no LSTMs) or the state isn't
|
||||
/// needed between batches.
|
||||
///
|
||||
/// Returns one of the following three status codes:
|
||||
/// 1. kTfLiteOk: Success. Output is valid.
|
||||
/// 2. kTfLiteDelegateError: Delegate error but fallback succeeded. Output is
|
||||
/// valid.
|
||||
/// NOTE: This undoes all delegates previously applied to the Interpreter.
|
||||
/// 3. kTfLiteError: Unexpected/runtime failure. Output is invalid.
///
/// `interpreter` must point to a valid Interpreter.
|
||||
/// WARNING: This is an experimental API and subject to change.
|
||||
static TfLiteStatus InvokeWithCPUFallback(Interpreter* interpreter);
|
||||
};
|
||||
#endif // !TFLITE_EXPERIMENTAL_RUNTIME_EAGER
|
||||
} // namespace delegates
|
||||
} // namespace tflite
|
||||
|
||||
#endif // TENSORFLOW_LITE_DELEGATES_INTERPRETER_UTILS_H_
|
@ -16,6 +16,8 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_LITE_DELEGATES_UTILS_H_
|
||||
#define TENSORFLOW_LITE_DELEGATES_UTILS_H_
|
||||
|
||||
// Utility functions and classes for implementing delegates.
|
||||
|
||||
#include <functional>
|
||||
#include <limits>
|
||||
#include <set>
|
||||
|
@ -310,6 +310,8 @@ void Interpreter::SetCancellationFunction(void* data,
|
||||
}
|
||||
}
|
||||
|
||||
// Forwards the cancellation query to the primary subgraph.
bool Interpreter::IsCancelled() {
  return primary_subgraph().IsCancelled();
}
|
||||
|
||||
TfLiteStatus Interpreter::ModifyGraphWithDelegate(TfLiteDelegate* delegate) {
|
||||
TfLiteStatus status = kTfLiteOk;
|
||||
for (auto& subgraph : subgraphs_) {
|
||||
@ -340,6 +342,8 @@ TfLiteStatus Interpreter::RemoveAllDelegates() {
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
// Reports whether the primary subgraph has delegates applied; only the
// primary subgraph is consulted.
bool Interpreter::HasDelegates() {
  return primary_subgraph().HasDelegates();
}
|
||||
|
||||
TfLiteStatus Interpreter::SetBufferHandle(int tensor_index,
|
||||
TfLiteBufferHandle buffer_handle,
|
||||
TfLiteDelegate* delegate) {
|
||||
|
@ -42,6 +42,9 @@ namespace tflite {
|
||||
|
||||
class InterpreterTest;
|
||||
class TestDelegate;
|
||||
namespace delegates {
|
||||
class InterpreterUtils; // Class for friend declarations.
|
||||
} // namespace delegates
|
||||
|
||||
namespace impl {
|
||||
|
||||
@ -529,6 +532,7 @@ class Interpreter {
|
||||
friend class InterpreterBuilder;
|
||||
friend class tflite::InterpreterTest;
|
||||
friend class tflite::TestDelegate;
|
||||
friend class tflite::delegates::InterpreterUtils;
|
||||
|
||||
/// Set the value of an external context.
|
||||
static void SetExternalContext(struct TfLiteContext* context,
|
||||
@ -542,6 +546,15 @@ class Interpreter {
|
||||
// afterwards.
|
||||
TfLiteStatus RemoveAllDelegates();
|
||||
|
||||
// Returns true if delegates have been applied.
|
||||
bool HasDelegates();
|
||||
|
||||
// Returns true if cancellation function returns true.
|
||||
bool IsCancelled();
|
||||
|
||||
// Get the error reporter associated with this interpreter.
|
||||
ErrorReporter* error_reporter() { return error_reporter_; }
|
||||
|
||||
// A pure C data structure used to communicate with the pure C plugin
|
||||
// interface. To avoid copying tensor metadata, this is also the definitive
|
||||
// structure to store tensors.
|
||||
|
Loading…
Reference in New Issue
Block a user