diff --git a/tensorflow/lite/BUILD b/tensorflow/lite/BUILD index c84972ea027..e80e32fe6cf 100644 --- a/tensorflow/lite/BUILD +++ b/tensorflow/lite/BUILD @@ -254,7 +254,6 @@ cc_library( ":shared_library", ":simple_memory_arena", ":string", - ":tflite_with_xnnpack_optional", ":type_to_tflitetype", ":util", ":version", diff --git a/tensorflow/lite/c/common.h b/tensorflow/lite/c/common.h index 23eb528f4c9..18997298df7 100644 --- a/tensorflow/lite/c/common.h +++ b/tensorflow/lite/c/common.h @@ -47,7 +47,8 @@ extern "C" { typedef enum TfLiteStatus { kTfLiteOk = 0, kTfLiteError = 1, - kTfLiteDelegateError = 2 + kTfLiteDelegateError = 2, + kTfLiteApplicationError = 3 } TfLiteStatus; // The list of external context types known to TF Lite. This list exists solely diff --git a/tensorflow/lite/core/api/op_resolver.h b/tensorflow/lite/core/api/op_resolver.h index 1294b7b8ea8..b6a8171d2a3 100644 --- a/tensorflow/lite/core/api/op_resolver.h +++ b/tensorflow/lite/core/api/op_resolver.h @@ -15,6 +15,8 @@ limitations under the License. #ifndef TENSORFLOW_LITE_CORE_API_OP_RESOLVER_H_ #define TENSORFLOW_LITE_CORE_API_OP_RESOLVER_H_ +#include <memory> + #include "tensorflow/lite/c/common.h" #include "tensorflow/lite/core/api/error_reporter.h" #include "tensorflow/lite/schema/schema_generated.h" @@ -32,6 +34,16 @@ class OpResolver { /// Finds the op registration of a custom operator by op name. virtual const TfLiteRegistration* FindOp(const char* op, int version) const = 0; + + // Returns optional delegates for resolving and handling ops in the flatbuffer + // model. This may be used in addition to the standard TfLiteRegistration + // lookup for graph resolution. 
+ using TfLiteDelegatePtrVector = + std::vector<std::unique_ptr<TfLiteDelegate, void (*)(TfLiteDelegate*)>>; + virtual TfLiteDelegatePtrVector GetDelegates(int num_threads) const { + return TfLiteDelegatePtrVector(); + } + virtual ~OpResolver() {} }; diff --git a/tensorflow/lite/core/subgraph.cc b/tensorflow/lite/core/subgraph.cc index 15b8a0bcc57..ecdb04c8b3c 100644 --- a/tensorflow/lite/core/subgraph.cc +++ b/tensorflow/lite/core/subgraph.cc @@ -1414,7 +1414,7 @@ TfLiteStatus Subgraph::ModifyGraphWithDelegate(TfLiteDelegate* delegate) { if (state_ == kStateInvokableAndImmutable) { ReportError( "ModifyGraphWithDelegate is disallowed when graph is immutable."); - return kTfLiteError; + return kTfLiteApplicationError; } if (!(delegate->flags & kTfLiteDelegateFlagsAllowDynamicTensors)) { diff --git a/tensorflow/lite/core/subgraph.h b/tensorflow/lite/core/subgraph.h index 1fe1c7e4391..3a28b4cb99c 100644 --- a/tensorflow/lite/core/subgraph.h +++ b/tensorflow/lite/core/subgraph.h @@ -558,12 +558,15 @@ class Subgraph { // be reallocated if the graph was modified (i.e., the caller does *not* need // to explicitly call |AllocateTensors()| again). If tensors were unallocated, // they will remain unallocated after delegate application. - // Returns one of the following three status codes: + // Returns one of the following status codes: // 1. kTfLiteOk: Delegation succeeded - // 2. kTfLiteDelegateError: Delegation failed due to an error in the - // delegate. The Subgraph has been restored to its pre-delegation state. + // 2. kTfLiteDelegateError: Delegation failed due to an error *in the + // delegate*. The Subgraph has been restored to its pre-delegation state. // NOTE: This reverts all delegates previously applied to the Subgraph. - // 3. kTfLiteError: Unexpected/runtime failure. + // 3. kTfLiteApplicationError: Delegation failed to be applied due to the + // state that the TfLite runtime is in. However, the Subgraph is still in an + // invokable state. + // 4. kTfLiteError: Unexpected/runtime failure. 
TfLiteStatus ModifyGraphWithDelegate(TfLiteDelegate* delegate); // This un-applies all delegates that have been applied till now, but retains diff --git a/tensorflow/lite/interpreter.cc b/tensorflow/lite/interpreter.cc index 4f81824d96f..a79ea86f61e 100644 --- a/tensorflow/lite/interpreter.cc +++ b/tensorflow/lite/interpreter.cc @@ -86,9 +86,8 @@ TfLiteQuantization GetQuantizationFromLegacy( } // namespace Interpreter::Interpreter(ErrorReporter* error_reporter) - : error_reporter_(error_reporter ? error_reporter : DefaultErrorReporter()), - lazy_delegate_provider_( - TfLiteDelegatePtr(nullptr, [](TfLiteDelegate*) {})) { + : error_reporter_(error_reporter ? error_reporter + : DefaultErrorReporter()) { // TODO(b/128420794): Include the TFLite runtime version in the log. // Prod logging is useful for mobile platforms where scraping console logs is // critical for debugging. @@ -184,21 +183,53 @@ TfLiteStatus Interpreter::SetVariables(std::vector<int> variables) { TfLiteStatus Interpreter::AllocateTensors() { // Apply the default delegate that TFLite will enable at this point to allow // other user-level delegates to be applied first. - if (lazy_delegate_provider_) { - // The execution will fall back to default implementation if the XNNPACK - // delegate fails to be applied. Therefore, we ignore the return status - // here and let it fall through the rest of the code. 
- auto status = ModifyGraphWithDelegate(std::move(lazy_delegate_provider_)); - if (status != kTfLiteOk) { - TF_LITE_REPORT_ERROR( - error_reporter_, - "Ignoring failed application of the default TensorFlow Lite " - "delegate."); - } else { - TFLITE_LOG(TFLITE_LOG_INFO, - "Successfully applied the default TensorFlow Lite delegate."); + if (!lazy_delegate_providers_.empty()) { + TFLITE_LOG(TFLITE_LOG_INFO, + "Applying %zu TensorFlow Lite delegate(s) lazily.", + lazy_delegate_providers_.size()); + // At the moment, XNNPACK delegate is the only one that might be applied + // by default, in which case, the execution will fall back to default + // implementation if the XNNPACK delegate fails to be applied. Therefore, we + // ignore the return status here and let it fall through the rest of the + // code. + for (size_t i = 0; i < lazy_delegate_providers_.size(); ++i) { + auto status = + ModifyGraphWithDelegate(std::move(lazy_delegate_providers_[i])); + switch (status) { + case kTfLiteOk: + TFLITE_LOG(TFLITE_LOG_INFO, + "Successfully applied the default TensorFlow Lite " + "delegate indexed at %zu.", + i); + break; + case kTfLiteError: + TF_LITE_REPORT_ERROR(error_reporter_, + "Failed to apply the default TensorFlow Lite " + "delegate indexed at %zu.", + i); + return kTfLiteError; + case kTfLiteDelegateError: + TF_LITE_REPORT_ERROR( + error_reporter_, + "Error in applying the default TensorFlow Lite delegate indexed " + "at %zu, and all previously applied delegates are reverted.", + i); + break; + case kTfLiteApplicationError: + TF_LITE_REPORT_ERROR(error_reporter_, + "Ignoring failed application of the default " + "TensorFlow Lite delegate indexed at %zu.", + i); + break; + default: + TF_LITE_REPORT_ERROR(error_reporter_, + "Unknown status (%d) after applying the default " + "TensorFlow Lite delegate indexed at %zu.", + status, i); + return kTfLiteError; + } } - lazy_delegate_provider_.reset(); + lazy_delegate_providers_.clear(); } return 
primary_subgraph().AllocateTensors(); diff --git a/tensorflow/lite/interpreter.h b/tensorflow/lite/interpreter.h index d4bf3016810..f27a17dfafe 100644 --- a/tensorflow/lite/interpreter.h +++ b/tensorflow/lite/interpreter.h @@ -653,10 +653,10 @@ class Interpreter { // A map of resources. Owned by interpreter and shared by multiple subgraphs. resource::ResourceMap resources_; - // Indicating a delegate that the TFLite interpreter will apply by default. - // A nullptr value means there's no delegate to be applied by default or the - // delegate has been applied and doesn't need to be applied again. - TfLiteDelegatePtr lazy_delegate_provider_; + // Indicating delegates that the TFLite interpreter will apply by default. + // An empty one means there's no delegate to be applied by default or + // delegates have been applied and don't need to be applied again. + std::vector<TfLiteDelegatePtr> lazy_delegate_providers_; }; } // namespace impl diff --git a/tensorflow/lite/interpreter_builder.cc b/tensorflow/lite/interpreter_builder.cc index 07c5251fab3..0765f00faf3 100644 --- a/tensorflow/lite/interpreter_builder.cc +++ b/tensorflow/lite/interpreter_builder.cc @@ -29,7 +29,6 @@ limitations under the License. 
#include "tensorflow/lite/kernels/internal/compatibility.h" #include "tensorflow/lite/schema/schema_generated.h" #include "tensorflow/lite/shared_library.h" -#include "tensorflow/lite/tflite_with_xnnpack_optional.h" #include "tensorflow/lite/util.h" #include "tensorflow/lite/version.h" @@ -675,8 +674,8 @@ TfLiteStatus InterpreterBuilder::operator()( } if (num_fp32_tensors_ > 0) { - (*interpreter)->lazy_delegate_provider_ = - MaybeCreateXNNPACKDelegate(num_threads); + (*interpreter)->lazy_delegate_providers_ = + op_resolver_.GetDelegates(num_threads); } if (ApplyDelegates(interpreter->get(), num_threads) != kTfLiteOk) diff --git a/tensorflow/lite/kernels/BUILD b/tensorflow/lite/kernels/BUILD index a56d370afeb..9a672dfa89d 100644 --- a/tensorflow/lite/kernels/BUILD +++ b/tensorflow/lite/kernels/BUILD @@ -757,6 +757,7 @@ cc_library( deps = [ ":builtin_op_kernels", "//tensorflow/lite:framework", + "//tensorflow/lite:tflite_with_xnnpack_optional", "//tensorflow/lite/c:common", "//tensorflow/lite/schema:schema_fbs", ], @@ -774,6 +775,7 @@ cc_library( deps = [ ":builtin_op_kernels", "//tensorflow/lite:framework_lib", + "//tensorflow/lite:tflite_with_xnnpack_optional", "//tensorflow/lite/c:common", "//tensorflow/lite/schema:schema_fbs", ], @@ -791,6 +793,7 @@ cc_library( deps = [ ":builtin_op_kernels_ruy_and_caching", "//tensorflow/lite:framework_lib", + "//tensorflow/lite:tflite_with_xnnpack_optional", "//tensorflow/lite/c:common", "//tensorflow/lite/schema:schema_fbs", ], diff --git a/tensorflow/lite/kernels/register.cc b/tensorflow/lite/kernels/register.cc index 3c16bfd097d..e020298fc8f 100644 --- a/tensorflow/lite/kernels/register.cc +++ b/tensorflow/lite/kernels/register.cc @@ -18,6 +18,7 @@ limitations under the License. 
#include "tensorflow/lite/c/common.h" #include "tensorflow/lite/kernels/builtin_op_kernels.h" #include "tensorflow/lite/schema/schema_generated.h" +#include "tensorflow/lite/tflite_with_xnnpack_optional.h" namespace tflite { namespace ops { @@ -303,6 +304,21 @@ BuiltinOpResolver::BuiltinOpResolver() { tflite::ops::custom::Register_DETECTION_POSTPROCESS()); } +OpResolver::TfLiteDelegatePtrVector BuiltinOpResolver::GetDelegates( + int num_threads) const { + OpResolver::TfLiteDelegatePtrVector delegates; + auto xnnpack_delegate = tflite::MaybeCreateXNNPACKDelegate(num_threads); + if (xnnpack_delegate != nullptr) { + delegates.push_back(std::move(xnnpack_delegate)); + } + return delegates; +} + +OpResolver::TfLiteDelegatePtrVector +BuiltinOpResolverWithoutDefaultDelegates::GetDelegates(int num_threads) const { + return OpResolver::TfLiteDelegatePtrVector(); +} + } // namespace builtin } // namespace ops } // namespace tflite diff --git a/tensorflow/lite/kernels/register.h b/tensorflow/lite/kernels/register.h index a2a41ea9428..1a6095c7140 100644 --- a/tensorflow/lite/kernels/register.h +++ b/tensorflow/lite/kernels/register.h @@ -22,9 +22,22 @@ namespace tflite { namespace ops { namespace builtin { +// This built-in op resolver provides a list of TfLite delegates that could be +// applied by TfLite interpreter by default. class BuiltinOpResolver : public MutableOpResolver { public: BuiltinOpResolver(); + OpResolver::TfLiteDelegatePtrVector GetDelegates( + int num_threads) const override; +}; + +// TfLite interpreter could apply a TfLite delegate by default. To completely +// disable this behavior, one could choose to use the following class +// BuiltinOpResolverWithoutDefaultDelegates. 
+class BuiltinOpResolverWithoutDefaultDelegates : public BuiltinOpResolver { + public: + BuiltinOpResolverWithoutDefaultDelegates() : BuiltinOpResolver() {} + OpResolver::TfLiteDelegatePtrVector GetDelegates(int num_threads) const final; }; } // namespace builtin } // namespace ops } // namespace tflite diff --git a/tensorflow/lite/model_xnnpack_test.cc b/tensorflow/lite/model_xnnpack_test.cc index 73860807c00..f04334c7711 100644 --- a/tensorflow/lite/model_xnnpack_test.cc +++ b/tensorflow/lite/model_xnnpack_test.cc @@ -30,7 +30,7 @@ TEST(FloatModel, WithXnnpackDelegate) { std::unique_ptr<Interpreter> interpreter; ASSERT_EQ(InterpreterBuilder(*model, - ops::builtin::BuiltinOpResolver{})(&interpreter), + ops::builtin::BuiltinOpResolver())(&interpreter), kTfLiteOk); ASSERT_TRUE(interpreter); @@ -48,4 +48,32 @@ TEST(FloatModel, WithXnnpackDelegate) { #endif } +TEST(FloatModel, DefaultXnnpackDelegateNotAllowed) { + // Note: this graph will be fully delegated by the XNNPACK delegate. + auto model = FlatBufferModel::BuildFromFile( + "tensorflow/lite/testdata/multi_add.bin"); + ASSERT_TRUE(model); + + std::unique_ptr<Interpreter> interpreter; + ASSERT_EQ( + InterpreterBuilder( + *model, ops::builtin::BuiltinOpResolverWithoutDefaultDelegates())( + &interpreter), + kTfLiteOk); + ASSERT_TRUE(interpreter); + + ASSERT_EQ(interpreter->AllocateTensors(), kTfLiteOk); + +#if TFLITE_HAS_ATTRIBUTE_WEAK || defined(TFLITE_BUILD_WITH_XNNPACK_DELEGATE) + // As we don't allow applying xnnpack delegate by default, we will expect the + // following: + EXPECT_LT(1, interpreter->execution_plan().size()); + int first_node_id = interpreter->execution_plan()[0]; + const auto& first_node_reg = + interpreter->node_and_registration(first_node_id)->second; + const std::string op_name = GetOpNameByRegistration(first_node_reg); + EXPECT_EQ("ADD", op_name); +#endif +} + } // namespace tflite diff --git a/tensorflow/lite/tools/benchmark/experimental/c/c_api_types.h b/tensorflow/lite/tools/benchmark/experimental/c/c_api_types.h index 23eb528f4c9..18997298df7 100644 --- 
a/tensorflow/lite/tools/benchmark/experimental/c/c_api_types.h +++ b/tensorflow/lite/tools/benchmark/experimental/c/c_api_types.h @@ -47,7 +47,8 @@ extern "C" { typedef enum TfLiteStatus { kTfLiteOk = 0, kTfLiteError = 1, - kTfLiteDelegateError = 2 + kTfLiteDelegateError = 2, + kTfLiteApplicationError = 3 } TfLiteStatus; // The list of external context types known to TF Lite. This list exists solely