Add flag for enabling delegate fallback to the TFLite C API.
This adds a new experimental flag to the C API interpreter options, which provides functionality equivalent to the TF Lite C++ API's tflite::delegates::InterpreterUtils::InvokeWithCPUFallback from delegates/interpreter_utils.h. PiperOrigin-RevId: 345757574 Change-Id: I91ee063babfc25a793535f7dd9d4541810e11650
This commit is contained in:
parent
1dfcf5cdd7
commit
f42bd2184c
@ -64,6 +64,7 @@ cc_library(
|
|||||||
"//tensorflow/lite:builtin_ops",
|
"//tensorflow/lite:builtin_ops",
|
||||||
"//tensorflow/lite:framework",
|
"//tensorflow/lite:framework",
|
||||||
"//tensorflow/lite:version",
|
"//tensorflow/lite:version",
|
||||||
|
"//tensorflow/lite/delegates:interpreter_utils",
|
||||||
"//tensorflow/lite/delegates/nnapi:nnapi_delegate",
|
"//tensorflow/lite/delegates/nnapi:nnapi_delegate",
|
||||||
"//tensorflow/lite/kernels:builtin_ops",
|
"//tensorflow/lite/kernels:builtin_ops",
|
||||||
"//tensorflow/lite/kernels/internal:compatibility",
|
"//tensorflow/lite/kernels/internal:compatibility",
|
||||||
@ -114,6 +115,7 @@ cc_test(
|
|||||||
":c_api_experimental",
|
":c_api_experimental",
|
||||||
":common",
|
":common",
|
||||||
"//tensorflow/lite:kernel_api",
|
"//tensorflow/lite:kernel_api",
|
||||||
|
"//tensorflow/lite/delegates:delegate_test_util",
|
||||||
"//tensorflow/lite/testing:util",
|
"//tensorflow/lite/testing:util",
|
||||||
"@com_google_googletest//:gtest",
|
"@com_google_googletest//:gtest",
|
||||||
],
|
],
|
||||||
|
@ -18,6 +18,7 @@ limitations under the License.
|
|||||||
|
|
||||||
#include "tensorflow/lite/builtin_ops.h"
|
#include "tensorflow/lite/builtin_ops.h"
|
||||||
#include "tensorflow/lite/c/c_api_internal.h"
|
#include "tensorflow/lite/c/c_api_internal.h"
|
||||||
|
#include "tensorflow/lite/delegates/interpreter_utils.h"
|
||||||
#include "tensorflow/lite/delegates/nnapi/nnapi_delegate.h"
|
#include "tensorflow/lite/delegates/nnapi/nnapi_delegate.h"
|
||||||
#include "tensorflow/lite/error_reporter.h"
|
#include "tensorflow/lite/error_reporter.h"
|
||||||
#include "tensorflow/lite/interpreter.h"
|
#include "tensorflow/lite/interpreter.h"
|
||||||
@ -165,8 +166,13 @@ TfLiteStatus TfLiteInterpreterAllocateTensors(TfLiteInterpreter* interpreter) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
TfLiteStatus TfLiteInterpreterInvoke(TfLiteInterpreter* interpreter) {
|
TfLiteStatus TfLiteInterpreterInvoke(TfLiteInterpreter* interpreter) {
|
||||||
|
if (interpreter->enable_delegate_fallback) {
|
||||||
|
return tflite::delegates::InterpreterUtils::InvokeWithCPUFallback(
|
||||||
|
interpreter->impl.get());
|
||||||
|
} else {
|
||||||
return interpreter->impl->Invoke();
|
return interpreter->impl->Invoke();
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
int32_t TfLiteInterpreterGetOutputTensorCount(
|
int32_t TfLiteInterpreterGetOutputTensorCount(
|
||||||
const TfLiteInterpreter* interpreter) {
|
const TfLiteInterpreter* interpreter) {
|
||||||
@ -298,8 +304,12 @@ TfLiteInterpreter* InterpreterCreateWithOpResolver(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool enable_delegate_fallback =
|
||||||
|
optional_options != nullptr && optional_options->enable_delegate_fallback;
|
||||||
|
|
||||||
return new TfLiteInterpreter{model->impl, std::move(optional_error_reporter),
|
return new TfLiteInterpreter{model->impl, std::move(optional_error_reporter),
|
||||||
std::move(interpreter)};
|
std::move(interpreter),
|
||||||
|
enable_delegate_fallback};
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace internal
|
} // namespace internal
|
||||||
|
@ -177,9 +177,34 @@ TFL_CAPI_EXPORT extern TfLiteStatus TfLiteInterpreterAllocateTensors(
|
|||||||
|
|
||||||
// Runs inference for the loaded graph.
|
// Runs inference for the loaded graph.
|
||||||
//
|
//
|
||||||
|
// Before calling this function, the caller should first invoke
|
||||||
|
// TfLiteInterpreterAllocateTensors() and should also set the values for the
|
||||||
|
// input tensors. After successfully calling this function, the values for the
|
||||||
|
// output tensors will be set.
|
||||||
|
//
|
||||||
// NOTE: It is possible that the interpreter is not in a ready state to
|
// NOTE: It is possible that the interpreter is not in a ready state to
|
||||||
// evaluate (e.g., if a ResizeInputTensor() has been performed without a call to
|
// evaluate (e.g., if AllocateTensors() hasn't been called, or if a
|
||||||
|
// ResizeInputTensor() has been performed without a subsequent call to
|
||||||
// AllocateTensors()).
|
// AllocateTensors()).
|
||||||
|
//
|
||||||
|
// If the (experimental!) delegate fallback option was enabled in the
|
||||||
|
// interpreter options, then the interpreter will automatically fall back to
|
||||||
|
// not using any delegates if execution with delegates fails. For details, see
|
||||||
|
// TfLiteInterpreterOptionsSetEnableDelegateFallback in c_api_experimental.h.
|
||||||
|
//
|
||||||
|
// Returns one of the following status codes:
|
||||||
|
// - kTfLiteOk: Success. Output is valid.
|
||||||
|
// - kTfLiteDelegateError: Execution with delegates failed, due to a problem
|
||||||
|
// with the delegate(s). If fallback was not enabled, output is invalid.
|
||||||
|
// If fallback was enabled, this return value indicates that fallback
|
||||||
|
// succeeded, the output is valid, and all delegates previously applied to
|
||||||
|
// the interpreter have been undone.
|
||||||
|
// - kTfLiteApplicationError: Same as for kTfLiteDelegateError, except that
|
||||||
|
// the problem was not with the delegate itself, but rather was
|
||||||
|
// due to an incompatibility between the delegate(s) and the
|
||||||
|
// interpreter or model.
|
||||||
|
// - kTfLiteError: Unexpected/runtime failure. Output is invalid.
|
||||||
|
|
||||||
TFL_CAPI_EXPORT extern TfLiteStatus TfLiteInterpreterInvoke(
|
TFL_CAPI_EXPORT extern TfLiteStatus TfLiteInterpreterInvoke(
|
||||||
TfLiteInterpreter* interpreter);
|
TfLiteInterpreter* interpreter);
|
||||||
|
|
||||||
|
@ -77,6 +77,11 @@ void TfLiteInterpreterOptionsSetUseNNAPI(TfLiteInterpreterOptions* options,
|
|||||||
options->use_nnapi = enable;
|
options->use_nnapi = enable;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void TfLiteInterpreterOptionsSetEnableDelegateFallback(
|
||||||
|
TfLiteInterpreterOptions* options, bool enable) {
|
||||||
|
options->enable_delegate_fallback = enable;
|
||||||
|
}
|
||||||
|
|
||||||
#ifdef __cplusplus
|
#ifdef __cplusplus
|
||||||
} // extern "C"
|
} // extern "C"
|
||||||
#endif // __cplusplus
|
#endif // __cplusplus
|
||||||
|
@ -113,12 +113,39 @@ TFL_CAPI_EXPORT extern TfLiteInterpreter*
|
|||||||
TfLiteInterpreterCreateWithSelectedOps(const TfLiteModel* model,
|
TfLiteInterpreterCreateWithSelectedOps(const TfLiteModel* model,
|
||||||
const TfLiteInterpreterOptions* options);
|
const TfLiteInterpreterOptions* options);
|
||||||
|
|
||||||
/// Enable or disable the NN API for the interpreter (true to enable).
|
/// Enable or disable the NN API delegate for the interpreter (true to enable).
|
||||||
///
|
///
|
||||||
/// WARNING: This is an experimental API and subject to change.
|
/// WARNING: This is an experimental API and subject to change.
|
||||||
TFL_CAPI_EXPORT extern void TfLiteInterpreterOptionsSetUseNNAPI(
|
TFL_CAPI_EXPORT extern void TfLiteInterpreterOptionsSetUseNNAPI(
|
||||||
TfLiteInterpreterOptions* options, bool enable);
|
TfLiteInterpreterOptions* options, bool enable);
|
||||||
|
|
||||||
|
/// Enable or disable CPU fallback for the interpreter (true to enable).
|
||||||
|
/// If enabled, TfLiteInterpreterInvoke will do automatic fallback from
|
||||||
|
/// executing with delegate(s) to regular execution without delegates
|
||||||
|
/// (i.e. on CPU).
|
||||||
|
///
|
||||||
|
/// Allowing the fallback is suitable only if both of the following hold:
|
||||||
|
/// - The caller is known not to cache pointers to tensor data across
|
||||||
|
/// TfLiteInterpreterInvoke calls.
|
||||||
|
/// - The model is not stateful (no variables, no LSTMs) or the state isn't
|
||||||
|
/// needed between batches.
|
||||||
|
///
|
||||||
|
/// When delegate fallback is enabled, TfLiteInterpreterInvoke will
|
||||||
|
/// behave as follows:
|
||||||
|
/// If one or more delegates were set in the interpreter options
|
||||||
|
/// (see TfLiteInterpreterOptionsAddDelegate),
|
||||||
|
/// AND inference fails,
|
||||||
|
/// then the interpreter will fall back to not using any delegates.
|
||||||
|
/// In that case, the previously applied delegate(s) will be automatically
|
||||||
|
/// undone, and an attempt will be made to return the interpreter to an
|
||||||
|
/// invokable state, which may invalidate previous tensor addresses,
|
||||||
|
/// and the inference will be attempted again, using input tensors with
|
||||||
|
/// the same value as previously set.
|
||||||
|
///
|
||||||
|
/// WARNING: This is an experimental API and subject to change.
|
||||||
|
TFL_CAPI_EXPORT extern void TfLiteInterpreterOptionsSetEnableDelegateFallback(
|
||||||
|
TfLiteInterpreterOptions* options, bool enable);
|
||||||
|
|
||||||
#ifdef __cplusplus
|
#ifdef __cplusplus
|
||||||
} // extern "C"
|
} // extern "C"
|
||||||
#endif // __cplusplus
|
#endif // __cplusplus
|
||||||
|
@ -15,12 +15,20 @@ limitations under the License.
|
|||||||
|
|
||||||
#include "tensorflow/lite/c/c_api_experimental.h"
|
#include "tensorflow/lite/c/c_api_experimental.h"
|
||||||
|
|
||||||
|
#include <string.h>
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
#include <gtest/gtest.h>
|
#include <gtest/gtest.h>
|
||||||
#include "tensorflow/lite/builtin_ops.h"
|
#include "tensorflow/lite/builtin_ops.h"
|
||||||
#include "tensorflow/lite/c/c_api.h"
|
#include "tensorflow/lite/c/c_api.h"
|
||||||
#include "tensorflow/lite/c/common.h"
|
#include "tensorflow/lite/c/common.h"
|
||||||
|
#include "tensorflow/lite/delegates/delegate_test_util.h"
|
||||||
#include "tensorflow/lite/testing/util.h"
|
#include "tensorflow/lite/testing/util.h"
|
||||||
|
|
||||||
|
using tflite::delegates::test_utils::TestDelegate;
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
const TfLiteRegistration* GetDummyRegistration() {
|
const TfLiteRegistration* GetDummyRegistration() {
|
||||||
@ -159,6 +167,130 @@ TEST(CApiExperimentalTest, SetOpResolver) {
|
|||||||
TfLiteModelDelete(model);
|
TfLiteModelDelete(model);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void AllocateAndSetInputs(TfLiteInterpreter* interpreter) {
|
||||||
|
std::array<int, 1> input_dims = {2};
|
||||||
|
ASSERT_EQ(TfLiteInterpreterResizeInputTensor(
|
||||||
|
interpreter, 0, input_dims.data(), input_dims.size()),
|
||||||
|
kTfLiteOk);
|
||||||
|
ASSERT_EQ(TfLiteInterpreterAllocateTensors(interpreter), kTfLiteOk);
|
||||||
|
TfLiteTensor* input_tensor = TfLiteInterpreterGetInputTensor(interpreter, 0);
|
||||||
|
ASSERT_NE(input_tensor, nullptr);
|
||||||
|
std::array<float, 2> input = {1.f, 3.f};
|
||||||
|
ASSERT_EQ(TfLiteTensorCopyFromBuffer(input_tensor, input.data(),
|
||||||
|
input.size() * sizeof(float)),
|
||||||
|
kTfLiteOk);
|
||||||
|
}
|
||||||
|
|
||||||
|
void VerifyOutputs(TfLiteInterpreter* interpreter) {
|
||||||
|
const TfLiteTensor* output_tensor =
|
||||||
|
TfLiteInterpreterGetOutputTensor(interpreter, 0);
|
||||||
|
ASSERT_NE(output_tensor, nullptr);
|
||||||
|
std::array<float, 2> output;
|
||||||
|
ASSERT_EQ(TfLiteTensorCopyToBuffer(output_tensor, output.data(),
|
||||||
|
output.size() * sizeof(float)),
|
||||||
|
kTfLiteOk);
|
||||||
|
EXPECT_EQ(output[0], 3.f);
|
||||||
|
EXPECT_EQ(output[1], 9.f);
|
||||||
|
}
|
||||||
|
|
||||||
|
void CheckExecution(TfLiteInterpreterOptions* options,
|
||||||
|
TfLiteStatus expected_first_result,
|
||||||
|
TfLiteStatus expected_subsequent_results) {
|
||||||
|
TfLiteModel* model =
|
||||||
|
TfLiteModelCreateFromFile("tensorflow/lite/testdata/add.bin");
|
||||||
|
ASSERT_NE(model, nullptr);
|
||||||
|
|
||||||
|
TfLiteInterpreter* interpreter = TfLiteInterpreterCreate(model, options);
|
||||||
|
ASSERT_NE(interpreter, nullptr);
|
||||||
|
|
||||||
|
AllocateAndSetInputs(interpreter);
|
||||||
|
for (int i = 0; i < 4; i++) {
|
||||||
|
bool result = TfLiteInterpreterInvoke(interpreter);
|
||||||
|
bool expected_result =
|
||||||
|
((i == 0) ? expected_first_result : expected_subsequent_results);
|
||||||
|
EXPECT_EQ(result, expected_result);
|
||||||
|
if (result != kTfLiteError) {
|
||||||
|
VerifyOutputs(interpreter);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
TfLiteInterpreterDelete(interpreter);
|
||||||
|
TfLiteModelDelete(model);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(TestDelegate, NoDelegate) {
|
||||||
|
TfLiteInterpreterOptions* options = TfLiteInterpreterOptionsCreate();
|
||||||
|
// Execution without any delegate should succeed.
|
||||||
|
CheckExecution(options, kTfLiteOk, kTfLiteOk);
|
||||||
|
TfLiteInterpreterOptionsDelete(options);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(TestDelegate, DelegateNodeInvokeFailure) {
|
||||||
|
// Initialize a delegate that will fail when invoked.
|
||||||
|
delegate_ = std::unique_ptr<SimpleDelegate>(new SimpleDelegate(
|
||||||
|
{0, 1}, kTfLiteDelegateFlagsNone, false /**fail_node_prepare**/,
|
||||||
|
0 /**min_ops_per_subset**/, true /**fail_node_invoke**/,
|
||||||
|
false /**automatic_shape_propagation**/, false /**custom_op**/));
|
||||||
|
// Create another interpreter with the delegate, without fallback.
|
||||||
|
TfLiteInterpreterOptions* options = TfLiteInterpreterOptionsCreate();
|
||||||
|
TfLiteInterpreterOptionsAddDelegate(options,
|
||||||
|
delegate_->get_tf_lite_delegate());
|
||||||
|
// Execution with the delegate should fail.
|
||||||
|
CheckExecution(options, kTfLiteError, kTfLiteError);
|
||||||
|
TfLiteInterpreterOptionsDelete(options);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(TestDelegate, DelegateNodeInvokeFailureFallback) {
|
||||||
|
// Initialize a delegate that will fail when invoked.
|
||||||
|
delegate_ = std::unique_ptr<SimpleDelegate>(new SimpleDelegate(
|
||||||
|
{0, 1}, kTfLiteDelegateFlagsNone, false /**fail_node_prepare**/,
|
||||||
|
0 /**min_ops_per_subset**/, true /**fail_node_invoke**/,
|
||||||
|
false /**automatic_shape_propagation**/, false /**custom_op**/));
|
||||||
|
// Create another interpreter with the delegate, with fallback enabled.
|
||||||
|
TfLiteInterpreterOptions* options = TfLiteInterpreterOptionsCreate();
|
||||||
|
TfLiteInterpreterOptionsAddDelegate(options,
|
||||||
|
delegate_->get_tf_lite_delegate());
|
||||||
|
TfLiteInterpreterOptionsSetEnableDelegateFallback(options, true);
|
||||||
|
CheckExecution(options,
|
||||||
|
// First execution will report DelegateError which indicates
|
||||||
|
// that the delegate failed but fallback succeeded.
|
||||||
|
kTfLiteDelegateError,
|
||||||
|
// Subsequent executions will not use the delegate and
|
||||||
|
// should therefore succeed.
|
||||||
|
kTfLiteOk);
|
||||||
|
TfLiteInterpreterOptionsDelete(options);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(TestDelegate, TestFallbackWithMultipleDelegates) {
|
||||||
|
// First delegate only supports node 0.
|
||||||
|
// This delegate should support dynamic tensors, otherwise the second won't be
|
||||||
|
// applied.
|
||||||
|
delegate_ = std::unique_ptr<SimpleDelegate>(new SimpleDelegate(
|
||||||
|
{0}, kTfLiteDelegateFlagsAllowDynamicTensors,
|
||||||
|
false /**fail_node_prepare**/, 0 /**min_ops_per_subset**/,
|
||||||
|
true /**fail_node_invoke**/, false /**automatic_shape_propagation**/,
|
||||||
|
false /**custom_op**/));
|
||||||
|
// Second delegate supports node 1, and makes the graph immutable.
|
||||||
|
delegate2_ = std::unique_ptr<SimpleDelegate>(new SimpleDelegate(
|
||||||
|
{1}, kTfLiteDelegateFlagsNone, false /**fail_node_prepare**/,
|
||||||
|
0 /**min_ops_per_subset**/, true /**fail_node_invoke**/,
|
||||||
|
false /**automatic_shape_propagation**/, false /**custom_op**/));
|
||||||
|
TfLiteInterpreterOptions* options = TfLiteInterpreterOptionsCreate();
|
||||||
|
TfLiteInterpreterOptionsAddDelegate(options,
|
||||||
|
delegate_->get_tf_lite_delegate());
|
||||||
|
TfLiteInterpreterOptionsAddDelegate(options,
|
||||||
|
delegate2_->get_tf_lite_delegate());
|
||||||
|
TfLiteInterpreterOptionsSetEnableDelegateFallback(options, true);
|
||||||
|
CheckExecution(options,
|
||||||
|
// First execution will report DelegateError which indicates
|
||||||
|
// that the delegate failed but fallback succeeded.
|
||||||
|
kTfLiteDelegateError,
|
||||||
|
// Subsequent executions will not use the delegate and
|
||||||
|
// should therefore succeed.
|
||||||
|
kTfLiteOk);
|
||||||
|
TfLiteInterpreterOptionsDelete(options);
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
int main(int argc, char** argv) {
|
int main(int argc, char** argv) {
|
||||||
|
@ -87,6 +87,12 @@ struct TfLiteInterpreterOptions {
|
|||||||
TfLiteErrorReporterCallback error_reporter_callback;
|
TfLiteErrorReporterCallback error_reporter_callback;
|
||||||
|
|
||||||
bool use_nnapi = false;
|
bool use_nnapi = false;
|
||||||
|
|
||||||
|
// Determines whether to allow automatic fallback to CPU.
|
||||||
|
// If true, and if one or more delegates were set,
|
||||||
|
// then if Invoke with delegates fails, it will be
|
||||||
|
// automatically retried without delegates.
|
||||||
|
bool enable_delegate_fallback = false;
|
||||||
};
|
};
|
||||||
|
|
||||||
struct TfLiteInterpreter {
|
struct TfLiteInterpreter {
|
||||||
@ -100,6 +106,8 @@ struct TfLiteInterpreter {
|
|||||||
std::unique_ptr<tflite::ErrorReporter> optional_error_reporter;
|
std::unique_ptr<tflite::ErrorReporter> optional_error_reporter;
|
||||||
|
|
||||||
std::unique_ptr<tflite::Interpreter> impl;
|
std::unique_ptr<tflite::Interpreter> impl;
|
||||||
|
|
||||||
|
bool enable_delegate_fallback;
|
||||||
};
|
};
|
||||||
|
|
||||||
namespace tflite {
|
namespace tflite {
|
||||||
|
@ -113,17 +113,16 @@ void TestDelegate::TearDown() {
|
|||||||
delegate_.reset();
|
delegate_.reset();
|
||||||
}
|
}
|
||||||
|
|
||||||
TestDelegate::SimpleDelegate::SimpleDelegate(const std::vector<int>& nodes,
|
TestDelegate::SimpleDelegate::SimpleDelegate(
|
||||||
int64_t delegate_flags,
|
const std::vector<int>& nodes, int64_t delegate_flags,
|
||||||
bool fail_node_prepare,
|
bool fail_node_prepare, int min_ops_per_subset, bool fail_node_invoke,
|
||||||
int min_ops_per_subset,
|
bool automatic_shape_propagation, bool custom_op)
|
||||||
bool fail_node_invoke,
|
|
||||||
bool automatic_shape_propagation)
|
|
||||||
: nodes_(nodes),
|
: nodes_(nodes),
|
||||||
fail_delegate_node_prepare_(fail_node_prepare),
|
fail_delegate_node_prepare_(fail_node_prepare),
|
||||||
min_ops_per_subset_(min_ops_per_subset),
|
min_ops_per_subset_(min_ops_per_subset),
|
||||||
fail_delegate_node_invoke_(fail_node_invoke),
|
fail_delegate_node_invoke_(fail_node_invoke),
|
||||||
automatic_shape_propagation_(automatic_shape_propagation) {
|
automatic_shape_propagation_(automatic_shape_propagation),
|
||||||
|
custom_op_(custom_op) {
|
||||||
delegate_.Prepare = [](TfLiteContext* context,
|
delegate_.Prepare = [](TfLiteContext* context,
|
||||||
TfLiteDelegate* delegate) -> TfLiteStatus {
|
TfLiteDelegate* delegate) -> TfLiteStatus {
|
||||||
auto* simple = static_cast<SimpleDelegate*>(delegate->data_);
|
auto* simple = static_cast<SimpleDelegate*>(delegate->data_);
|
||||||
@ -137,8 +136,12 @@ TestDelegate::SimpleDelegate::SimpleDelegate(const std::vector<int>& nodes,
|
|||||||
TfLiteNode* node;
|
TfLiteNode* node;
|
||||||
TfLiteRegistration* reg;
|
TfLiteRegistration* reg;
|
||||||
context->GetNodeAndRegistration(context, node_index, &node, ®);
|
context->GetNodeAndRegistration(context, node_index, &node, ®);
|
||||||
|
if (simple->custom_op_) {
|
||||||
TFLITE_CHECK_EQ(reg->builtin_code, tflite::BuiltinOperator_CUSTOM);
|
TFLITE_CHECK_EQ(reg->builtin_code, tflite::BuiltinOperator_CUSTOM);
|
||||||
TFLITE_CHECK_EQ(strcmp(reg->custom_name, "my_add"), 0);
|
TFLITE_CHECK_EQ(strcmp(reg->custom_name, "my_add"), 0);
|
||||||
|
} else {
|
||||||
|
TFLITE_CHECK_EQ(reg->builtin_code, tflite::BuiltinOperator_ADD);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
// Check that all nodes are available
|
// Check that all nodes are available
|
||||||
TfLiteIntArray* execution_plan;
|
TfLiteIntArray* execution_plan;
|
||||||
@ -150,8 +153,12 @@ TestDelegate::SimpleDelegate::SimpleDelegate(const std::vector<int>& nodes,
|
|||||||
context->GetNodeAndRegistration(context, node_index, &node, ®);
|
context->GetNodeAndRegistration(context, node_index, &node, ®);
|
||||||
if (exec_index == node_index) {
|
if (exec_index == node_index) {
|
||||||
// Check op details only if it wasn't delegated already.
|
// Check op details only if it wasn't delegated already.
|
||||||
|
if (simple->custom_op_) {
|
||||||
TFLITE_CHECK_EQ(reg->builtin_code, tflite::BuiltinOperator_CUSTOM);
|
TFLITE_CHECK_EQ(reg->builtin_code, tflite::BuiltinOperator_CUSTOM);
|
||||||
TFLITE_CHECK_EQ(strcmp(reg->custom_name, "my_add"), 0);
|
TFLITE_CHECK_EQ(strcmp(reg->custom_name, "my_add"), 0);
|
||||||
|
} else {
|
||||||
|
TFLITE_CHECK_EQ(reg->builtin_code, tflite::BuiltinOperator_ADD);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -182,7 +189,7 @@ TestDelegate::SimpleDelegate::SimpleDelegate(const std::vector<int>& nodes,
|
|||||||
sizeof(int) * nodes_to_separate->size);
|
sizeof(int) * nodes_to_separate->size);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Another call to PreviewDelegateParitioning should be okay, since
|
// Another call to PreviewDelegatePartitioning should be okay, since
|
||||||
// partitioning memory is managed by context.
|
// partitioning memory is managed by context.
|
||||||
TFLITE_CHECK_EQ(
|
TFLITE_CHECK_EQ(
|
||||||
context->PreviewDelegatePartitioning(context, nodes_to_separate,
|
context->PreviewDelegatePartitioning(context, nodes_to_separate,
|
||||||
|
@ -54,18 +54,25 @@ class TestDelegate : public ::testing::Test {
|
|||||||
// Create a simple implementation of a TfLiteDelegate. We use the C++ class
|
// Create a simple implementation of a TfLiteDelegate. We use the C++ class
|
||||||
// SimpleDelegate and it can produce a handle TfLiteDelegate that is
|
// SimpleDelegate and it can produce a handle TfLiteDelegate that is
|
||||||
// value-copyable and compatible with TfLite.
|
// value-copyable and compatible with TfLite.
|
||||||
|
//
|
||||||
|
// Parameters:
|
||||||
|
// nodes: Indices of the graph nodes that the delegate will handle.
|
||||||
// fail_node_prepare: To simulate failure of Delegate node's Prepare().
|
// fail_node_prepare: To simulate failure of Delegate node's Prepare().
|
||||||
// min_ops_per_subset: If >0, partitioning preview is used to choose only
|
// min_ops_per_subset: If >0, partitioning preview is used to choose only
|
||||||
// those subsets with min_ops_per_subset number of nodes.
|
// those subsets with min_ops_per_subset number of nodes.
|
||||||
// fail_node_invoke: To simulate failure of Delegate node's Invoke().
|
// fail_node_invoke: To simulate failure of Delegate node's Invoke().
|
||||||
// automatic_shape_propagation: This assumes that the runtime will propagate
|
// automatic_shape_propagation: This assumes that the runtime will
|
||||||
// shapes using the original execution plan.
|
// propagate shapes using the original execution plan.
|
||||||
|
// custom_op: If true, the graph nodes specified in the 'nodes' parameter
|
||||||
|
// should be custom ops with name "my_add"; if false, they should be
|
||||||
|
// the builtin ADD operator.
|
||||||
explicit SimpleDelegate(const std::vector<int>& nodes,
|
explicit SimpleDelegate(const std::vector<int>& nodes,
|
||||||
int64_t delegate_flags = kTfLiteDelegateFlagsNone,
|
int64_t delegate_flags = kTfLiteDelegateFlagsNone,
|
||||||
bool fail_node_prepare = false,
|
bool fail_node_prepare = false,
|
||||||
int min_ops_per_subset = 0,
|
int min_ops_per_subset = 0,
|
||||||
bool fail_node_invoke = false,
|
bool fail_node_invoke = false,
|
||||||
bool automatic_shape_propagation = false);
|
bool automatic_shape_propagation = false,
|
||||||
|
bool custom_op = true);
|
||||||
|
|
||||||
TfLiteRegistration FakeFusedRegistration();
|
TfLiteRegistration FakeFusedRegistration();
|
||||||
|
|
||||||
@ -80,6 +87,7 @@ class TestDelegate : public ::testing::Test {
|
|||||||
int min_ops_per_subset_ = 0;
|
int min_ops_per_subset_ = 0;
|
||||||
bool fail_delegate_node_invoke_ = false;
|
bool fail_delegate_node_invoke_ = false;
|
||||||
bool automatic_shape_propagation_ = false;
|
bool automatic_shape_propagation_ = false;
|
||||||
|
bool custom_op_ = true;
|
||||||
};
|
};
|
||||||
|
|
||||||
std::unique_ptr<Interpreter> interpreter_;
|
std::unique_ptr<Interpreter> interpreter_;
|
||||||
|
@ -137,6 +137,7 @@ $(wildcard tensorflow/lite/c/*.c) \
|
|||||||
$(wildcard tensorflow/lite/c/*.cc) \
|
$(wildcard tensorflow/lite/c/*.cc) \
|
||||||
$(wildcard tensorflow/lite/core/*.cc) \
|
$(wildcard tensorflow/lite/core/*.cc) \
|
||||||
$(wildcard tensorflow/lite/core/api/*.cc) \
|
$(wildcard tensorflow/lite/core/api/*.cc) \
|
||||||
|
$(wildcard tensorflow/lite/delegates/interpreter_utils.cc) \
|
||||||
$(wildcard tensorflow/lite/experimental/resource/*.cc) \
|
$(wildcard tensorflow/lite/experimental/resource/*.cc) \
|
||||||
$(wildcard tensorflow/lite/schema/schema_utils.cc) \
|
$(wildcard tensorflow/lite/schema/schema_utils.cc) \
|
||||||
$(wildcard tensorflow/lite/tools/make/downloads/ruy/ruy/*.cc)
|
$(wildcard tensorflow/lite/tools/make/downloads/ruy/ruy/*.cc)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user