diff --git a/tensorflow/lite/interpreter.cc b/tensorflow/lite/interpreter.cc
index cae2ca7dde0..b49aa5031bf 100644
--- a/tensorflow/lite/interpreter.cc
+++ b/tensorflow/lite/interpreter.cc
@@ -86,8 +86,9 @@ TfLiteQuantization GetQuantizationFromLegacy(
 }  // namespace

 Interpreter::Interpreter(ErrorReporter* error_reporter)
-    : error_reporter_(error_reporter ? error_reporter
-                                     : DefaultErrorReporter()) {
+    : error_reporter_(error_reporter ? error_reporter : DefaultErrorReporter()),
+      lazy_delegate_provider_(
+          TfLiteDelegatePtr(nullptr, [](TfLiteDelegate*) {})) {
   // TODO(b/128420794): Include the TFLite runtime version in the log.
   // Prod logging is useful for mobile platforms where scraping console logs is
   // critical for debugging.
@@ -175,6 +176,16 @@ TfLiteStatus Interpreter::SetVariables(std::vector<int> variables) {
 }

 TfLiteStatus Interpreter::AllocateTensors() {
+  // Apply the default delegate that TFLite will enable at this point to allow
+  // user-level delegates to be applied first.
+  if (lazy_delegate_provider_) {
+    // The execution will fall back to the default implementation if the
+    // XNNPACK delegate fails to be applied. Therefore, we ignore the return
+    // status here and let it fall through the rest of the code.
+    ModifyGraphWithDelegate(std::move(lazy_delegate_provider_));
+    lazy_delegate_provider_.reset();
+  }
+
   return primary_subgraph().AllocateTensors();
 }

diff --git a/tensorflow/lite/interpreter.h b/tensorflow/lite/interpreter.h
index 59cab6add6d..41377c4ce1f 100644
--- a/tensorflow/lite/interpreter.h
+++ b/tensorflow/lite/interpreter.h
@@ -347,10 +347,12 @@ class Interpreter {
   /// WARNING: Experimental interface, subject to change
   TfLiteStatus ReleaseNonPersistentMemory();

-  /// Update allocations for all tensors. This will redim dependent tensors
-  /// using the input tensor dimensionality as given. This is relatively
-  /// expensive. If you know that your sizes are not changing, you need not call
-  /// this. Returns status of success or failure.
+  /// Update allocations for all tensors. This will redim dependent tensors
+  /// using the input tensor dimensionality as given. This is relatively
+  /// expensive. This *must be* called after the interpreter has been created
+  /// and before running inference (and accessing tensor buffers), and *must
+  /// be* called again if (and only if) an input tensor is resized. Returns
+  /// status of success or failure.
   TfLiteStatus AllocateTensors();

   /// Invoke the interpreter (run the whole graph in dependency order).
@@ -594,6 +596,11 @@ class Interpreter {
   // A map of resources. Owned by interpreter and shared by multiple subgraphs.
   resource::ResourceMap resources_;
+
+  // A delegate that the TFLite interpreter will apply by default. A nullptr
+  // value means there is no delegate to be applied by default, or that the
+  // delegate has already been applied and doesn't need to be applied again.
+  TfLiteDelegatePtr lazy_delegate_provider_;
 };

 }  // namespace impl

diff --git a/tensorflow/lite/interpreter_builder.cc b/tensorflow/lite/interpreter_builder.cc
index d73b298e595..4b491d41881 100644
--- a/tensorflow/lite/interpreter_builder.cc
+++ b/tensorflow/lite/interpreter_builder.cc
@@ -545,17 +545,7 @@ TfLiteStatus InterpreterBuilder::ParseTensors(

 TfLiteStatus InterpreterBuilder::ApplyDelegates(Interpreter* interpreter,
                                                 int num_threads) {
-  // First, apply XNNPACK delegate if applicable.
-  if (num_fp32_tensors_ > 0) {
-    // The execution will fall back to default implementation if the XNNPACK
-    // delegate fails to be applied. Therefore, we ignore the return status
-    // here and let it fall through the rest of the code.
-    if (auto xnnpack_delegate = MaybeCreateXNNPACKDelegate(num_threads)) {
-      interpreter->ModifyGraphWithDelegate(std::move(xnnpack_delegate));
-    }
-  }
-
-  // Secondly, apply Flex delegate if applicable.
+  // Apply the Flex delegate if applicable.
   if (has_flex_op_) {
     if (auto flex_delegate = AcquireFlexDelegate()) {
       return interpreter->ModifyGraphWithDelegate(std::move(flex_delegate));
@@ -672,6 +662,11 @@ TfLiteStatus InterpreterBuilder::operator()(
     modified_subgraph->SetVariables(std::move(variables));
   }

+  if (num_fp32_tensors_ > 0) {
+    (*interpreter)->lazy_delegate_provider_ =
+        MaybeCreateXNNPACKDelegate(num_threads);
+  }
+
   if (ApplyDelegates(interpreter->get(), num_threads) != kTfLiteOk)
     return cleanup_and_error();
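
Note (not part of the patch): a minimal usage sketch of the new ordering, assuming the standard TFLite C++ API. RunWithUserDelegate is a hypothetical helper; the point it illustrates is that a delegate applied via ModifyGraphWithDelegate() before AllocateTensors() now takes effect first, and the default XNNPACK delegate held in lazy_delegate_provider_ is only applied later, inside AllocateTensors().

// Hypothetical usage sketch, not part of this patch.
#include <memory>

#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/model.h"

TfLiteStatus RunWithUserDelegate(const char* model_path,
                                 TfLiteDelegate* user_delegate) {
  auto model = tflite::FlatBufferModel::BuildFromFile(model_path);
  if (!model) return kTfLiteError;

  tflite::ops::builtin::BuiltinOpResolver resolver;
  std::unique_ptr<tflite::Interpreter> interpreter;
  if (tflite::InterpreterBuilder(*model, resolver)(&interpreter) != kTfLiteOk) {
    return kTfLiteError;
  }

  // The user-level delegate is applied first, before any tensors are
  // allocated...
  if (interpreter->ModifyGraphWithDelegate(user_delegate) != kTfLiteOk) {
    return kTfLiteError;
  }

  // ...and only here does the interpreter apply the pending default (XNNPACK)
  // delegate to whatever parts of the graph remain undelegated.
  if (interpreter->AllocateTensors() != kTfLiteOk) return kTfLiteError;
  return interpreter->Invoke();
}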