Add flag for enabling delegate fallback to the TFLite C API.

This adds a new experimental flag to the C API interpreter options,
which provides functionality equivalent to the TF Lite C++ API's
tflite::delegates::InterpreterUtils::InvokeWithCPUFallback
from delegates/interpreter_utils.h.
PiperOrigin-RevId: 345757574
Change-Id: I91ee063babfc25a793535f7dd9d4541810e11650
This commit is contained in:
Fergus Henderson 2020-12-04 14:28:02 -08:00 committed by TensorFlower Gardener
parent 1dfcf5cdd7
commit f42bd2184c
10 changed files with 248 additions and 23 deletions

View File

@ -64,6 +64,7 @@ cc_library(
"//tensorflow/lite:builtin_ops",
"//tensorflow/lite:framework",
"//tensorflow/lite:version",
"//tensorflow/lite/delegates:interpreter_utils",
"//tensorflow/lite/delegates/nnapi:nnapi_delegate",
"//tensorflow/lite/kernels:builtin_ops",
"//tensorflow/lite/kernels/internal:compatibility",
@ -114,6 +115,7 @@ cc_test(
":c_api_experimental",
":common",
"//tensorflow/lite:kernel_api",
"//tensorflow/lite/delegates:delegate_test_util",
"//tensorflow/lite/testing:util",
"@com_google_googletest//:gtest",
],

View File

@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/lite/builtin_ops.h"
#include "tensorflow/lite/c/c_api_internal.h"
#include "tensorflow/lite/delegates/interpreter_utils.h"
#include "tensorflow/lite/delegates/nnapi/nnapi_delegate.h"
#include "tensorflow/lite/error_reporter.h"
#include "tensorflow/lite/interpreter.h"
@ -165,8 +166,13 @@ TfLiteStatus TfLiteInterpreterAllocateTensors(TfLiteInterpreter* interpreter) {
}
// Runs inference on the interpreter. If the (experimental) delegate-fallback
// option was enabled in the interpreter options, delegated execution that
// fails is automatically retried on CPU via InvokeWithCPUFallback; otherwise
// this is a plain Invoke() on the underlying tflite::Interpreter.
TfLiteStatus TfLiteInterpreterInvoke(TfLiteInterpreter* interpreter) {
  if (!interpreter->enable_delegate_fallback) {
    return interpreter->impl->Invoke();
  }
  return tflite::delegates::InterpreterUtils::InvokeWithCPUFallback(
      interpreter->impl.get());
}
int32_t TfLiteInterpreterGetOutputTensorCount(
const TfLiteInterpreter* interpreter) {
@ -298,8 +304,12 @@ TfLiteInterpreter* InterpreterCreateWithOpResolver(
}
}
bool enable_delegate_fallback =
optional_options != nullptr && optional_options->enable_delegate_fallback;
return new TfLiteInterpreter{model->impl, std::move(optional_error_reporter),
std::move(interpreter)};
std::move(interpreter),
enable_delegate_fallback};
}
} // namespace internal

View File

@ -177,9 +177,34 @@ TFL_CAPI_EXPORT extern TfLiteStatus TfLiteInterpreterAllocateTensors(
// Runs inference for the loaded graph.
//
// Before calling this function, the caller should first invoke
// TfLiteInterpreterAllocateTensors() and should also set the values for the
// input tensors. After successfully calling this function, the values for the
// output tensors will be set.
//
// NOTE: It is possible that the interpreter is not in a ready state to
// evaluate (e.g., if AllocateTensors() hasn't been called, or if a
// ResizeInputTensor() has been performed without a subsequent call to
// AllocateTensors()).
//
// If the (experimental!) delegate fallback option was enabled in the
// interpreter options, then the interpreter will automatically fall back to
// not using any delegates if execution with delegates fails. For details, see
// TfLiteInterpreterOptionsSetEnableDelegateFallback in c_api_experimental.h.
//
// Returns one of the following status codes:
// - kTfLiteOk: Success. Output is valid.
// - kTfLiteDelegateError: Execution with delegates failed, due to a problem
// with the delegate(s). If fallback was not enabled, output is invalid.
// If fallback was enabled, this return value indicates that fallback
// succeeded, the output is valid, and all delegates previously applied to
// the interpreter have been undone.
// - kTfLiteApplicationError: Same as for kTfLiteDelegateError, except that
// the problem was not with the delegate itself, but rather was
// due to an incompatibility between the delegate(s) and the
// interpreter or model.
// - kTfLiteError: Unexpected/runtime failure. Output is invalid.
TFL_CAPI_EXPORT extern TfLiteStatus TfLiteInterpreterInvoke(
TfLiteInterpreter* interpreter);

View File

@ -77,6 +77,11 @@ void TfLiteInterpreterOptionsSetUseNNAPI(TfLiteInterpreterOptions* options,
options->use_nnapi = enable;
}
// Sets whether TfLiteInterpreterInvoke should automatically fall back to
// delegate-free (CPU) execution when execution with delegates fails.
// The flag is read when the interpreter is created from these options; see
// TfLiteInterpreterOptionsSetEnableDelegateFallback in c_api_experimental.h
// for the full contract.
void TfLiteInterpreterOptionsSetEnableDelegateFallback(
TfLiteInterpreterOptions* options, bool enable) {
options->enable_delegate_fallback = enable;
}
#ifdef __cplusplus
} // extern "C"
#endif // __cplusplus

View File

@ -113,12 +113,39 @@ TFL_CAPI_EXPORT extern TfLiteInterpreter*
TfLiteInterpreterCreateWithSelectedOps(const TfLiteModel* model,
const TfLiteInterpreterOptions* options);
/// Enable or disable the NN API delegate for the interpreter (true to enable).
///
/// WARNING: This is an experimental API and subject to change.
TFL_CAPI_EXPORT extern void TfLiteInterpreterOptionsSetUseNNAPI(
TfLiteInterpreterOptions* options, bool enable);
/// Enable or disable CPU fallback for the interpreter (true to enable).
/// If enabled, TfLiteInterpreterInvoke will do automatic fallback from
/// executing with delegate(s) to regular execution without delegates
/// (i.e. on CPU).
///
/// Allowing the fallback is suitable only if both of the following hold:
/// - The caller is known not to cache pointers to tensor data across
/// TfLiteInterpreterInvoke calls.
/// - The model is not stateful (no variables, no LSTMs) or the state isn't
/// needed between batches.
///
/// When delegate fallback is enabled, TfLiteInterpreterInvoke will
/// behave as follows:
/// If one or more delegates were set in the interpreter options
/// (see TfLiteInterpreterOptionsAddDelegate),
/// AND inference fails,
/// then the interpreter will fall back to not using any delegates.
/// In that case, the previously applied delegate(s) will be automatically
/// undone, and an attempt will be made to return the interpreter to an
/// invokable state, which may invalidate previous tensor addresses,
/// and the inference will be attempted again, using input tensors with
/// the same value as previously set.
///
/// WARNING: This is an experimental API and subject to change.
TFL_CAPI_EXPORT extern void TfLiteInterpreterOptionsSetEnableDelegateFallback(
TfLiteInterpreterOptions* options, bool enable);
#ifdef __cplusplus
} // extern "C"
#endif // __cplusplus

View File

@ -15,12 +15,20 @@ limitations under the License.
#include "tensorflow/lite/c/c_api_experimental.h"
#include <string.h>
#include <memory>
#include <vector>
#include <gtest/gtest.h>
#include "tensorflow/lite/builtin_ops.h"
#include "tensorflow/lite/c/c_api.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/delegates/delegate_test_util.h"
#include "tensorflow/lite/testing/util.h"
using tflite::delegates::test_utils::TestDelegate;
namespace {
const TfLiteRegistration* GetDummyRegistration() {
@ -159,6 +167,130 @@ TEST(CApiExperimentalTest, SetOpResolver) {
TfLiteModelDelete(model);
}
// Resizes input tensor 0 of `interpreter` to shape [2], allocates tensors,
// and fills the input with the values {1.f, 3.f}.
void AllocateAndSetInputs(TfLiteInterpreter* interpreter) {
  const std::array<int, 1> dims = {2};
  ASSERT_EQ(TfLiteInterpreterResizeInputTensor(interpreter, 0, dims.data(),
                                               dims.size()),
            kTfLiteOk);
  ASSERT_EQ(TfLiteInterpreterAllocateTensors(interpreter), kTfLiteOk);
  TfLiteTensor* in = TfLiteInterpreterGetInputTensor(interpreter, 0);
  ASSERT_NE(in, nullptr);
  const std::array<float, 2> values = {1.f, 3.f};
  ASSERT_EQ(TfLiteTensorCopyFromBuffer(in, values.data(),
                                       values.size() * sizeof(float)),
            kTfLiteOk);
}
// Reads output tensor 0 of `interpreter` and checks that it holds {3.f, 9.f}
// — the expected result for the inputs {1.f, 3.f} set by
// AllocateAndSetInputs.
void VerifyOutputs(TfLiteInterpreter* interpreter) {
  const TfLiteTensor* out = TfLiteInterpreterGetOutputTensor(interpreter, 0);
  ASSERT_NE(out, nullptr);
  std::array<float, 2> actual;
  ASSERT_EQ(TfLiteTensorCopyToBuffer(out, actual.data(),
                                     actual.size() * sizeof(float)),
            kTfLiteOk);
  EXPECT_EQ(actual[0], 3.f);
  EXPECT_EQ(actual[1], 9.f);
}
// Runs inference four times on tensorflow/lite/testdata/add.bin with the
// given options, expecting the first TfLiteInterpreterInvoke call to return
// `expected_first_result` and subsequent calls to return
// `expected_subsequent_results`. Output values are verified whenever the
// invocation did not hard-fail with kTfLiteError (per the Invoke contract,
// outputs are valid for kTfLiteOk and after a successful delegate fallback,
// which reports kTfLiteDelegateError).
void CheckExecution(TfLiteInterpreterOptions* options,
                    TfLiteStatus expected_first_result,
                    TfLiteStatus expected_subsequent_results) {
  TfLiteModel* model =
      TfLiteModelCreateFromFile("tensorflow/lite/testdata/add.bin");
  ASSERT_NE(model, nullptr);
  TfLiteInterpreter* interpreter = TfLiteInterpreterCreate(model, options);
  ASSERT_NE(interpreter, nullptr);
  AllocateAndSetInputs(interpreter);
  for (int i = 0; i < 4; i++) {
    // BUG FIX: the result must be kept as TfLiteStatus, not bool. Narrowing
    // to bool conflates kTfLiteDelegateError (2) with kTfLiteError (1), so
    // the EXPECT_EQ below could not distinguish a successful fallback from a
    // hard failure, and the `result != kTfLiteError` guard (bool vs. enum)
    // wrongly skipped output verification after a successful fallback.
    TfLiteStatus result = TfLiteInterpreterInvoke(interpreter);
    TfLiteStatus expected_result =
        ((i == 0) ? expected_first_result : expected_subsequent_results);
    EXPECT_EQ(result, expected_result);
    if (result != kTfLiteError) {
      VerifyOutputs(interpreter);
    }
  }
  TfLiteInterpreterDelete(interpreter);
  TfLiteModelDelete(model);
}
// Baseline: with no delegate applied, every TfLiteInterpreterInvoke call
// succeeds, with or without the fallback flag.
TEST_F(TestDelegate, NoDelegate) {
TfLiteInterpreterOptions* options = TfLiteInterpreterOptionsCreate();
// Execution without any delegate should succeed.
CheckExecution(options, kTfLiteOk, kTfLiteOk);
TfLiteInterpreterOptionsDelete(options);
}
// A delegate whose node Invoke() fails, with fallback NOT enabled: every
// TfLiteInterpreterInvoke call must report kTfLiteError and outputs stay
// invalid.
TEST_F(TestDelegate, DelegateNodeInvokeFailure) {
// Initialize a delegate that will fail when invoked.
delegate_ = std::unique_ptr<SimpleDelegate>(new SimpleDelegate(
{0, 1}, kTfLiteDelegateFlagsNone, /*fail_node_prepare=*/false,
/*min_ops_per_subset=*/0, /*fail_node_invoke=*/true,
/*automatic_shape_propagation=*/false, /*custom_op=*/false));
// Create another interpreter with the delegate, without fallback.
TfLiteInterpreterOptions* options = TfLiteInterpreterOptionsCreate();
TfLiteInterpreterOptionsAddDelegate(options,
delegate_->get_tf_lite_delegate());
// Execution with the delegate should fail.
CheckExecution(options, kTfLiteError, kTfLiteError);
TfLiteInterpreterOptionsDelete(options);
}
// A delegate whose node Invoke() fails, with fallback enabled: the first
// call reports kTfLiteDelegateError (delegate failed, fallback succeeded,
// delegates undone), and subsequent calls run delegate-free and succeed.
TEST_F(TestDelegate, DelegateNodeInvokeFailureFallback) {
// Initialize a delegate that will fail when invoked.
delegate_ = std::unique_ptr<SimpleDelegate>(new SimpleDelegate(
{0, 1}, kTfLiteDelegateFlagsNone, /*fail_node_prepare=*/false,
/*min_ops_per_subset=*/0, /*fail_node_invoke=*/true,
/*automatic_shape_propagation=*/false, /*custom_op=*/false));
// Create another interpreter with the delegate, with fallback enabled.
TfLiteInterpreterOptions* options = TfLiteInterpreterOptionsCreate();
TfLiteInterpreterOptionsAddDelegate(options,
delegate_->get_tf_lite_delegate());
TfLiteInterpreterOptionsSetEnableDelegateFallback(options, true);
CheckExecution(options,
// First execution will report DelegateError which indicates
// that the delegate failed but fallback succeeded.
kTfLiteDelegateError,
// Subsequent executions will not use the delegate and
// should therefore succeed.
kTfLiteOk);
TfLiteInterpreterOptionsDelete(options);
}
// Fallback with two delegates applied: when invocation fails, ALL previously
// applied delegates must be undone before the delegate-free retry.
TEST_F(TestDelegate, TestFallbackWithMultipleDelegates) {
// First delegate only supports node 0.
// This delegate should support dynamic tensors, otherwise the second won't be
// applied.
delegate_ = std::unique_ptr<SimpleDelegate>(new SimpleDelegate(
{0}, kTfLiteDelegateFlagsAllowDynamicTensors,
/*fail_node_prepare=*/false, /*min_ops_per_subset=*/0,
/*fail_node_invoke=*/true, /*automatic_shape_propagation=*/false,
/*custom_op=*/false));
// Second delegate supports node 1, and makes the graph immutable.
delegate2_ = std::unique_ptr<SimpleDelegate>(new SimpleDelegate(
{1}, kTfLiteDelegateFlagsNone, /*fail_node_prepare=*/false,
/*min_ops_per_subset=*/0, /*fail_node_invoke=*/true,
/*automatic_shape_propagation=*/false, /*custom_op=*/false));
TfLiteInterpreterOptions* options = TfLiteInterpreterOptionsCreate();
TfLiteInterpreterOptionsAddDelegate(options,
delegate_->get_tf_lite_delegate());
TfLiteInterpreterOptionsAddDelegate(options,
delegate2_->get_tf_lite_delegate());
TfLiteInterpreterOptionsSetEnableDelegateFallback(options, true);
CheckExecution(options,
// First execution will report DelegateError which indicates
// that the delegate failed but fallback succeeded.
kTfLiteDelegateError,
// Subsequent executions will not use the delegate and
// should therefore succeed.
kTfLiteOk);
TfLiteInterpreterOptionsDelete(options);
}
} // namespace
int main(int argc, char** argv) {

View File

@ -87,6 +87,12 @@ struct TfLiteInterpreterOptions {
TfLiteErrorReporterCallback error_reporter_callback;
bool use_nnapi = false;
// Determines whether to allow automatic fallback to CPU.
// If true, and if one or more delegates were set,
// then if Invoke with delegates fails, it will be
// automatically retried without delegates.
bool enable_delegate_fallback = false;
};
struct TfLiteInterpreter {
@ -100,6 +106,8 @@ struct TfLiteInterpreter {
std::unique_ptr<tflite::ErrorReporter> optional_error_reporter;
std::unique_ptr<tflite::Interpreter> impl;
bool enable_delegate_fallback;
};
namespace tflite {

View File

@ -113,17 +113,16 @@ void TestDelegate::TearDown() {
delegate_.reset();
}
TestDelegate::SimpleDelegate::SimpleDelegate(const std::vector<int>& nodes,
int64_t delegate_flags,
bool fail_node_prepare,
int min_ops_per_subset,
bool fail_node_invoke,
bool automatic_shape_propagation)
TestDelegate::SimpleDelegate::SimpleDelegate(
const std::vector<int>& nodes, int64_t delegate_flags,
bool fail_node_prepare, int min_ops_per_subset, bool fail_node_invoke,
bool automatic_shape_propagation, bool custom_op)
: nodes_(nodes),
fail_delegate_node_prepare_(fail_node_prepare),
min_ops_per_subset_(min_ops_per_subset),
fail_delegate_node_invoke_(fail_node_invoke),
automatic_shape_propagation_(automatic_shape_propagation) {
automatic_shape_propagation_(automatic_shape_propagation),
custom_op_(custom_op) {
delegate_.Prepare = [](TfLiteContext* context,
TfLiteDelegate* delegate) -> TfLiteStatus {
auto* simple = static_cast<SimpleDelegate*>(delegate->data_);
@ -137,8 +136,12 @@ TestDelegate::SimpleDelegate::SimpleDelegate(const std::vector<int>& nodes,
TfLiteNode* node;
TfLiteRegistration* reg;
context->GetNodeAndRegistration(context, node_index, &node, &reg);
if (simple->custom_op_) {
TFLITE_CHECK_EQ(reg->builtin_code, tflite::BuiltinOperator_CUSTOM);
TFLITE_CHECK_EQ(strcmp(reg->custom_name, "my_add"), 0);
} else {
TFLITE_CHECK_EQ(reg->builtin_code, tflite::BuiltinOperator_ADD);
}
}
// Check that all nodes are available
TfLiteIntArray* execution_plan;
@ -150,8 +153,12 @@ TestDelegate::SimpleDelegate::SimpleDelegate(const std::vector<int>& nodes,
context->GetNodeAndRegistration(context, node_index, &node, &reg);
if (exec_index == node_index) {
// Check op details only if it wasn't delegated already.
if (simple->custom_op_) {
TFLITE_CHECK_EQ(reg->builtin_code, tflite::BuiltinOperator_CUSTOM);
TFLITE_CHECK_EQ(strcmp(reg->custom_name, "my_add"), 0);
} else {
TFLITE_CHECK_EQ(reg->builtin_code, tflite::BuiltinOperator_ADD);
}
}
}
@ -182,7 +189,7 @@ TestDelegate::SimpleDelegate::SimpleDelegate(const std::vector<int>& nodes,
sizeof(int) * nodes_to_separate->size);
}
// Another call to PreviewDelegatePartitioning should be okay, since
// partitioning memory is managed by context.
TFLITE_CHECK_EQ(
context->PreviewDelegatePartitioning(context, nodes_to_separate,

View File

@ -54,18 +54,25 @@ class TestDelegate : public ::testing::Test {
// Create a simple implementation of a TfLiteDelegate. We use the C++ class
// SimpleDelegate and it can produce a handle TfLiteDelegate that is
// value-copyable and compatible with TfLite.
//
// Parameters:
// nodes: Indices of the graph nodes that the delegate will handle.
// fail_node_prepare: To simulate failure of Delegate node's Prepare().
// min_ops_per_subset: If >0, partitioning preview is used to choose only
// those subsets with min_ops_per_subset number of nodes.
// fail_node_invoke: To simulate failure of Delegate node's Invoke().
// automatic_shape_propagation: This assumes that the runtime will
// propagate shapes using the original execution plan.
// custom_op: If true, the graph nodes specified in the 'nodes' parameter
// should be custom ops with name "my_add"; if false, they should be
// the builtin ADD operator.
explicit SimpleDelegate(const std::vector<int>& nodes,
int64_t delegate_flags = kTfLiteDelegateFlagsNone,
bool fail_node_prepare = false,
int min_ops_per_subset = 0,
bool fail_node_invoke = false,
bool automatic_shape_propagation = false);
bool automatic_shape_propagation = false,
bool custom_op = true);
TfLiteRegistration FakeFusedRegistration();
@ -80,6 +87,7 @@ class TestDelegate : public ::testing::Test {
int min_ops_per_subset_ = 0;
bool fail_delegate_node_invoke_ = false;
bool automatic_shape_propagation_ = false;
bool custom_op_ = true;
};
std::unique_ptr<Interpreter> interpreter_;

View File

@ -137,6 +137,7 @@ $(wildcard tensorflow/lite/c/*.c) \
$(wildcard tensorflow/lite/c/*.cc) \
$(wildcard tensorflow/lite/core/*.cc) \
$(wildcard tensorflow/lite/core/api/*.cc) \
$(wildcard tensorflow/lite/delegates/interpreter_utils.cc) \
$(wildcard tensorflow/lite/experimental/resource/*.cc) \
$(wildcard tensorflow/lite/schema/schema_utils.cc) \
$(wildcard tensorflow/lite/tools/make/downloads/ruy/ruy/*.cc)