Roll forward API for automatic fallback on delegation failure.

PiperOrigin-RevId: 313353207
Change-Id: I0f7824ecc5421a179c10a6de4fc5192e9815abb7
This commit is contained in:
A. Unique TensorFlower 2020-05-27 02:22:10 -07:00 committed by TensorFlower Gardener
parent 9084090e8d
commit 7738c1818e
10 changed files with 234 additions and 3 deletions

View File

@ -376,7 +376,9 @@ cc_test(
cc_test(
name = "interpreter_test",
size = "small",
srcs = ["interpreter_test.cc"],
srcs = [
"interpreter_test.cc",
],
features = ["-dynamic_link_test_srcs"], # see go/dynamic_link_test_srcs
tags = [
"tflite_not_portable_ios", # TODO(b/117786830)

View File

@ -533,6 +533,11 @@ void Subgraph::SetCancellationFunction(void* data,
check_cancelled_func_ = check_cancelled_func;
}
// Reports whether the registered cancellation callback currently requests
// cancellation. Without a registered callback the answer is always false.
bool Subgraph::IsCancelled() {
  // Nothing to consult when no cancellation callback has been set.
  if (check_cancelled_func_ == nullptr) {
    return false;
  }
  // Ask the callback, handing it the opaque data stored alongside it.
  return (*check_cancelled_func_)(cancellation_data_);
}
// Forwards `count` to nodes_and_registration_.reserve().
// Presumably called before a known number of nodes is added so the container
// does not have to grow repeatedly -- confirm against callers.
void Subgraph::ReserveNodes(int count) {
nodes_and_registration_.reserve(count);
}
@ -1316,6 +1321,8 @@ TfLiteStatus Subgraph::RemoveAllDelegates() {
return kTfLiteOk;
}
// Returns true when delegates_applied_ is non-empty, i.e. at least one
// delegate is currently recorded as applied to this subgraph.
bool Subgraph::HasDelegates() { return !delegates_applied_.empty(); }
TfLiteStatus Subgraph::EnsureMemoryAllocations() {
if (memory_planner_) {
state_ = kStateUninvokable;

View File

@ -553,6 +553,9 @@ class Subgraph {
// afterwards.
TfLiteStatus RemoveAllDelegates();
// Returns true if the subgraph has delegates applied.
bool HasDelegates();
// Cleans up data reserved for the given node. Does not remove the {node,
// registration} pair from nodes_and_registration_.
void CleanupNode(int node_index);
@ -578,6 +581,9 @@ class Subgraph {
// Ensures the memory required is planned and allocated.
TfLiteStatus EnsureMemoryAllocations();
// Returns true if cancellation function returns true.
bool IsCancelled();
// The state of the Interpreter.
enum State {
// The interpreter isn't ready to be invoked.

View File

@ -32,6 +32,16 @@ cc_library(
],
)
# Utilities for using delegates with an Interpreter (interpreter_utils.h/.cc).
# Depends only on the core TFLite framework target.
cc_library(
name = "interpreter_utils",
srcs = ["interpreter_utils.cc"],
hdrs = ["interpreter_utils.h"],
copts = tflite_copts(),
deps = [
"//tensorflow/lite:framework",
],
)
cc_test(
name = "utils_test",
srcs = ["utils_test.cc"],
@ -53,6 +63,7 @@ cc_test(
"tflite_not_portable_ios", # TODO(b/117786830)
],
deps = [
":interpreter_utils",
"//tensorflow/lite:framework",
"//tensorflow/lite:version",
"//tensorflow/lite/core/api",

View File

@ -19,6 +19,7 @@ limitations under the License.
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "tensorflow/lite/delegates/interpreter_utils.h"
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/kernels/internal/compatibility.h"
#include "tensorflow/lite/kernels/kernel_util.h"
@ -261,8 +262,10 @@ class TestDelegate : public ::testing::Test {
for (int i = 0; i < num; i++) {
out->data.f[i] = a0->data.f[i] + a1->data.f[i];
}
// Make the data stale so that CopyFromBufferHandle can be invoked
out->data_is_stale = true;
if (out->buffer_handle != kTfLiteNullBufferHandle) {
// Make the data stale so that CopyFromBufferHandle can be invoked
out->data_is_stale = true;
}
return kTfLiteOk;
};
if (fail_delegate_node_invoke_) {
@ -397,6 +400,34 @@ TEST_F(TestDelegate, DelegateNodeInvokeFailure) {
}
}
// A single delegate takes over nodes 0-2 and is configured to fail when its
// node is invoked. InvokeWithCPUFallback() is expected to report
// kTfLiteDelegateError, restore the original three-node execution plan and
// still produce the expected output values.
TEST_F(TestDelegate, DelegateNodeInvokeFailureFallback) {
  delegate_.reset(new SimpleDelegate(
      {0, 1, 2}, kTfLiteDelegateFlagsNone, false /**fail_node_prepare**/,
      0 /**min_ops_per_subset**/, true /**fail_node_invoke**/));
  ASSERT_EQ(
      interpreter_->ModifyGraphWithDelegate(delegate_->get_tf_lite_delegate()),
      kTfLiteOk);
  // After delegation the whole graph collapses into one delegate node.
  ASSERT_EQ(interpreter_->execution_plan().size(), 1);

  const std::vector<float> input_values = {1.0f, 2.0f, 3.0f};
  const std::vector<float> expected_values = {2.0f, 4.0f, 6.0f};
  constexpr int kOutputTensorIndex = 3;

  // Both input tensors (0 and 1) receive the same three floats.
  const auto input_bytes = input_values.size() * sizeof(float);
  for (int input_index = 0; input_index < 2; ++input_index) {
    memcpy(interpreter_->typed_tensor<float>(input_index), input_values.data(),
           input_bytes);
  }

  const TfLiteStatus fallback_status =
      delegates::InterpreterUtils::InvokeWithCPUFallback(interpreter_.get());
  EXPECT_EQ(fallback_status, kTfLiteDelegateError);

  // The delegate has been removed, so the original plan is back.
  ASSERT_EQ(interpreter_->execution_plan().size(), 3);

  // Validate the output element by element.
  TfLiteTensor* output = interpreter_->tensor(kOutputTensorIndex);
  for (int i = 0; i < 3; ++i) {
    EXPECT_EQ(output->data.f[i], expected_values[i]) << i;
  }
}
TEST_F(TestDelegate, SecondDelegationPrepareFailure) {
// First delegate only supports nodes 1, 2. Gets applied successfully.
// This delegate should support dynamic tensors, otherwise the second won't be
@ -713,6 +744,44 @@ TEST_F(TestDelegate, TestResizeInputWithMultipleDelegates) {
}
}
// Two delegates are stacked: the first handles node 0, the second handles
// nodes 1 and 2 and fails at invoke time. InvokeWithCPUFallback() is expected
// to report kTfLiteDelegateError, undo both delegates and still compute the
// expected values.
TEST_F(TestDelegate, TestFallbackWithMultipleDelegates) {
  // The first delegate must allow dynamic tensors; otherwise the second one
  // would not be applied.
  delegate_.reset(
      new SimpleDelegate({0}, kTfLiteDelegateFlagsAllowDynamicTensors));
  // The second delegate covers nodes 1 & 2, makes the graph immutable, and is
  // set up to fail when its node is invoked.
  delegate2_.reset(new SimpleDelegate(
      {1, 2}, kTfLiteDelegateFlagsNone, false /**fail_node_prepare**/,
      0 /**min_ops_per_subset**/, true /**fail_node_invoke**/));

  // Three nodes before any delegate is applied.
  ASSERT_EQ(interpreter_->execution_plan().size(), 3);
  ASSERT_EQ(
      interpreter_->ModifyGraphWithDelegate(delegate_->get_tf_lite_delegate()),
      kTfLiteOk);
  ASSERT_EQ(
      interpreter_->ModifyGraphWithDelegate(delegate2_->get_tf_lite_delegate()),
      kTfLiteOk);
  // One delegate node per applied delegate.
  ASSERT_EQ(interpreter_->execution_plan().size(), 2);

  const std::vector<float> input_values = {1.0f, 2.0f, 3.0f};
  const std::vector<float> expected_values = {2.0f, 4.0f, 6.0f};
  constexpr int kOutputTensorIndex = 2;
  TfLiteTensor* output = interpreter_->tensor(kOutputTensorIndex);

  // Both input tensors (0 and 1) receive the same three floats.
  const auto input_bytes = input_values.size() * sizeof(float);
  for (int input_index = 0; input_index < 2; ++input_index) {
    memcpy(interpreter_->typed_tensor<float>(input_index), input_values.data(),
           input_bytes);
  }

  const TfLiteStatus fallback_status =
      delegates::InterpreterUtils::InvokeWithCPUFallback(interpreter_.get());
  EXPECT_EQ(fallback_status, kTfLiteDelegateError);

  // Every delegate should have been undone.
  EXPECT_EQ(interpreter_->execution_plan().size(), 3);
  for (int i = 0; i < 3; ++i) {
    EXPECT_EQ(output->data.f[i], expected_values[i]) << i;
  }
}
TEST_F(TestDelegate, ReleaseNonPersistentMemoryWithDelegates) {
// First delegate only supports node 0.
// This delegate should support dynamic tensors, otherwise the second won't be

View File

@ -0,0 +1,65 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/delegates/interpreter_utils.h"
namespace tflite {
namespace delegates {
// Invokes `interpreter` and, if that fails while delegates are applied,
// removes all delegates and invokes once more without them.
//
// Returns:
//   - the status of the first Invoke() unchanged when it succeeded, when the
//     interpreter reports cancellation, or when no delegates are applied;
//   - kTfLiteDelegateError when the first Invoke() failed but the retry
//     without delegates succeeded;
//   - kTfLiteError when any step of the fallback itself fails.
//
// Every failure on the fallback path is reported as kTfLiteError on purpose:
// kTfLiteDelegateError is reserved for "fallback succeeded, output is valid",
// so forwarding a failing step's own status could make a failed fallback look
// like a successful one to the caller.
//
// `interpreter` is dereferenced unconditionally and must not be null.
TfLiteStatus InterpreterUtils::InvokeWithCPUFallback(Interpreter* interpreter) {
  TfLiteStatus status = interpreter->Invoke();
  if (status == kTfLiteOk || interpreter->IsCancelled() ||
      !interpreter->HasDelegates()) {
    return status;
  }
  // Retry without delegation.
  // TODO(b/138706191): retry only if error is due to delegation.
  TF_LITE_REPORT_ERROR(
      interpreter->error_reporter(),
      "Invoke() failed in the presence of delegation. Retrying without.");

  // Copy input data to a buffer.
  // Input data is safe since Subgraph::PrepareOpsAndTensors() passes
  // preserve_inputs=true to ArenaPlanner.
  std::vector<char> buf;
  size_t input_size = 0;
  for (auto i : interpreter->inputs()) {
    // Make sure the tensor data is readable before it is copied.
    if (interpreter->EnsureTensorDataIsReadable(i) != kTfLiteOk) {
      return kTfLiteError;
    }
    TfLiteTensor* t = interpreter->tensor(i);
    input_size += t->bytes;
  }

  // One allocation for all inputs, stored back to back in inputs() order.
  buf.reserve(input_size);
  for (auto i : interpreter->inputs()) {
    TfLiteTensor* t = interpreter->tensor(i);
    buf.insert(buf.end(), t->data.raw, t->data.raw + t->bytes);
  }

  if (interpreter->RemoveAllDelegates() != kTfLiteOk) {
    return kTfLiteError;
  }

  // Copy inputs from buffer, walking it in the same order it was filled.
  auto bufp = buf.begin();
  for (auto i : interpreter->inputs()) {
    TfLiteTensor* t = interpreter->tensor(i);
    std::copy(bufp, bufp + t->bytes, t->data.raw);
    bufp += t->bytes;
  }

  // Invoke again, this time without any delegate.
  if (interpreter->Invoke() != kTfLiteOk) {
    return kTfLiteError;
  }
  // The retry worked; still tell the caller that delegation failed.
  return kTfLiteDelegateError;
}
} // namespace delegates
} // namespace tflite

View File

@ -0,0 +1,52 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_DELEGATES_INTERPRETER_UTILS_H_
#define TENSORFLOW_LITE_DELEGATES_INTERPRETER_UTILS_H_
#include "tensorflow/lite/interpreter.h"
// Utility functions and classes for using delegates.
namespace tflite {
namespace delegates {
#if !TFLITE_EXPERIMENTAL_RUNTIME_EAGER
/// Static helpers for running an Interpreter that has delegates applied.
class InterpreterUtils {
public:
/// Invokes an interpreter with automatic fallback from delegation to CPU.
///
/// If using the delegate fails, the delegate is automatically undone and an
/// attempt made to return the interpreter to an invokable state. The contents
/// of the input tensors are saved before the delegates are removed and
/// written back afterwards, then the interpreter is invoked a second time.
///
/// No retry is attempted (and the status of the first Invoke() is returned
/// as is) when the interpreter reports cancellation or has no delegates
/// applied.
///
/// Allowing the fallback is suitable only if both of the following hold:
/// - The caller is known not to cache pointers to tensor data across Invoke()
/// calls.
/// - The model is not stateful (no variables, no LSTMs) or the state isn't
/// needed between batches.
///
/// `interpreter` must be non-null; it is dereferenced unconditionally.
///
/// Returns one of the following three status codes:
/// 1. kTfLiteOk: Success. Output is valid.
/// 2. kTfLiteDelegateError: Delegate error but fallback succeeded. Output is
/// valid.
/// NOTE: This undoes all delegates previously applied to the Interpreter.
/// 3. kTfLiteError: Unexpected/runtime failure. Output is invalid.
///
/// WARNING: This is an experimental API and subject to change.
static TfLiteStatus InvokeWithCPUFallback(Interpreter* interpreter);
};
#endif // !TFLITE_EXPERIMENTAL_RUNTIME_EAGER
} // namespace delegates
} // namespace tflite
#endif // TENSORFLOW_LITE_DELEGATES_INTERPRETER_UTILS_H_

View File

@ -16,6 +16,8 @@ limitations under the License.
#ifndef TENSORFLOW_LITE_DELEGATES_UTILS_H_
#define TENSORFLOW_LITE_DELEGATES_UTILS_H_
// Utility functions and classes for implementing delegates.
#include <functional>
#include <limits>
#include <set>

View File

@ -310,6 +310,8 @@ void Interpreter::SetCancellationFunction(void* data,
}
}
// Returns the cancellation state reported by the primary subgraph.
bool Interpreter::IsCancelled() { return primary_subgraph().IsCancelled(); }
TfLiteStatus Interpreter::ModifyGraphWithDelegate(TfLiteDelegate* delegate) {
TfLiteStatus status = kTfLiteOk;
for (auto& subgraph : subgraphs_) {
@ -340,6 +342,8 @@ TfLiteStatus Interpreter::RemoveAllDelegates() {
return kTfLiteOk;
}
// Returns whether the primary subgraph has delegates applied.
// NOTE(review): only the primary subgraph is consulted; confirm that no other
// subgraph can have delegates applied while the primary one has none.
bool Interpreter::HasDelegates() { return primary_subgraph().HasDelegates(); }
TfLiteStatus Interpreter::SetBufferHandle(int tensor_index,
TfLiteBufferHandle buffer_handle,
TfLiteDelegate* delegate) {

View File

@ -42,6 +42,9 @@ namespace tflite {
class InterpreterTest;
class TestDelegate;
namespace delegates {
class InterpreterUtils; // Class for friend declarations.
} // namespace delegates
namespace impl {
@ -529,6 +532,7 @@ class Interpreter {
friend class InterpreterBuilder;
friend class tflite::InterpreterTest;
friend class tflite::TestDelegate;
friend class tflite::delegates::InterpreterUtils;
/// Set the value of an external context.
static void SetExternalContext(struct TfLiteContext* context,
@ -542,6 +546,15 @@ class Interpreter {
// afterwards.
TfLiteStatus RemoveAllDelegates();
// Returns true if delegates have been applied.
bool HasDelegates();
// Returns true if cancellation function returns true.
bool IsCancelled();
// Get the error reporter associated with this interpreter.
ErrorReporter* error_reporter() { return error_reporter_; }
// A pure C data structure used to communicate with the pure C plugin
// interface. To avoid copying tensor metadata, this is also the definitive
// structure to store tensors.