Use BuiltinOpResolver as a way to apply the xnnpack delegate by default in TfLite interpreter. Also, provide another builtin-op resolver to disallow applying the delegate by default.
PiperOrigin-RevId: 327378746 Change-Id: I801790cf2878875fcf23c4781306e8243c8fd0af
This commit is contained in:
parent
7c8b6efc14
commit
4bb7dca446
@ -254,7 +254,6 @@ cc_library(
|
||||
":shared_library",
|
||||
":simple_memory_arena",
|
||||
":string",
|
||||
":tflite_with_xnnpack_optional",
|
||||
":type_to_tflitetype",
|
||||
":util",
|
||||
":version",
|
||||
|
@ -47,7 +47,8 @@ extern "C" {
|
||||
typedef enum TfLiteStatus {
|
||||
kTfLiteOk = 0,
|
||||
kTfLiteError = 1,
|
||||
kTfLiteDelegateError = 2
|
||||
kTfLiteDelegateError = 2,
|
||||
kTfLiteApplicationError = 3
|
||||
} TfLiteStatus;
|
||||
|
||||
// The list of external context types known to TF Lite. This list exists solely
|
||||
|
@ -15,6 +15,8 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_LITE_CORE_API_OP_RESOLVER_H_
|
||||
#define TENSORFLOW_LITE_CORE_API_OP_RESOLVER_H_
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/lite/c/common.h"
|
||||
#include "tensorflow/lite/core/api/error_reporter.h"
|
||||
#include "tensorflow/lite/schema/schema_generated.h"
|
||||
@ -32,6 +34,16 @@ class OpResolver {
|
||||
/// Finds the op registration of a custom operator by op name.
|
||||
virtual const TfLiteRegistration* FindOp(const char* op,
|
||||
int version) const = 0;
|
||||
|
||||
// Returns optional delegates for resolving and handling ops in the flatbuffer
|
||||
// model. This may be used in addition to the standard TfLiteRegistration
|
||||
// lookup for graph resolution.
|
||||
using TfLiteDelegatePtrVector =
|
||||
std::vector<std::unique_ptr<TfLiteDelegate, void (*)(TfLiteDelegate*)>>;
|
||||
virtual TfLiteDelegatePtrVector GetDelegates(int num_threads) const {
|
||||
return TfLiteDelegatePtrVector();
|
||||
}
|
||||
|
||||
virtual ~OpResolver() {}
|
||||
};
|
||||
|
||||
|
@ -1414,7 +1414,7 @@ TfLiteStatus Subgraph::ModifyGraphWithDelegate(TfLiteDelegate* delegate) {
|
||||
if (state_ == kStateInvokableAndImmutable) {
|
||||
ReportError(
|
||||
"ModifyGraphWithDelegate is disallowed when graph is immutable.");
|
||||
return kTfLiteError;
|
||||
return kTfLiteApplicationError;
|
||||
}
|
||||
|
||||
if (!(delegate->flags & kTfLiteDelegateFlagsAllowDynamicTensors)) {
|
||||
|
@ -558,12 +558,15 @@ class Subgraph {
|
||||
// be reallocated if the graph was modified (i.e., the caller does *not* need
|
||||
// to explicitly call |AllocateTensors()| again). If tensors were unallocated,
|
||||
// they will remain unallocated after delegate application.
|
||||
// Returns one of the following three status codes:
|
||||
// Returns one of the following status codes:
|
||||
// 1. kTfLiteOk: Delegation succeeded
|
||||
// 2. kTfLiteDelegateError: Delegation failed due to an error in the
|
||||
// delegate. The Subgraph has been restored to its pre-delegation state.
|
||||
// 2. kTfLiteDelegateError: Delegation failed due to an error *in the
|
||||
// delegate*. The Subgraph has been restored to its pre-delegation state.
|
||||
// NOTE: This reverts all delegates previously applied to the Subgraph.
|
||||
// 3. kTfLiteError: Unexpected/runtime failure.
|
||||
// 3. kTfLiteApplicationError : Delegation failed to be applied due to the
|
||||
// state that the TfLite runtime is in. However, the Subgraph is still in a
|
||||
// invokable state.
|
||||
// 4. kTfLiteError: Unexpected/runtime failure.
|
||||
TfLiteStatus ModifyGraphWithDelegate(TfLiteDelegate* delegate);
|
||||
|
||||
// This un-applies all delegates that have been applied till now, but retains
|
||||
|
@ -86,9 +86,8 @@ TfLiteQuantization GetQuantizationFromLegacy(
|
||||
} // namespace
|
||||
|
||||
Interpreter::Interpreter(ErrorReporter* error_reporter)
|
||||
: error_reporter_(error_reporter ? error_reporter : DefaultErrorReporter()),
|
||||
lazy_delegate_provider_(
|
||||
TfLiteDelegatePtr(nullptr, [](TfLiteDelegate*) {})) {
|
||||
: error_reporter_(error_reporter ? error_reporter
|
||||
: DefaultErrorReporter()) {
|
||||
// TODO(b/128420794): Include the TFLite runtime version in the log.
|
||||
// Prod logging is useful for mobile platforms where scraping console logs is
|
||||
// critical for debugging.
|
||||
@ -184,21 +183,53 @@ TfLiteStatus Interpreter::SetVariables(std::vector<int> variables) {
|
||||
TfLiteStatus Interpreter::AllocateTensors() {
|
||||
// Apply the default delegate that TFLite will enable at this point to allow
|
||||
// other user-level delegates to be applied first.
|
||||
if (lazy_delegate_provider_) {
|
||||
// The execution will fall back to default implementation if the XNNPACK
|
||||
// delegate fails to be applied. Therefore, we ignore the return status
|
||||
// here and let it fall through the rest of the code.
|
||||
auto status = ModifyGraphWithDelegate(std::move(lazy_delegate_provider_));
|
||||
if (status != kTfLiteOk) {
|
||||
TF_LITE_REPORT_ERROR(
|
||||
error_reporter_,
|
||||
"Ignoring failed application of the default TensorFlow Lite "
|
||||
"delegate.");
|
||||
} else {
|
||||
TFLITE_LOG(TFLITE_LOG_INFO,
|
||||
"Successfully applied the default TensorFlow Lite delegate.");
|
||||
if (!lazy_delegate_providers_.empty()) {
|
||||
TFLITE_LOG(TFLITE_LOG_INFO,
|
||||
"Applying %zu TensorFlow Lite delegate(s) lazily.",
|
||||
lazy_delegate_providers_.size());
|
||||
// At the momement, XNNPACK delegate is the only one that might be applied
|
||||
// by default, in which case, the execution will fall back to default
|
||||
// implementation if the XNNPACK delegate fails to be applied. Therefore, we
|
||||
// ignore the return status here and let it fall through the rest of the
|
||||
// code.
|
||||
for (size_t i = 0; i < lazy_delegate_providers_.size(); ++i) {
|
||||
auto status =
|
||||
ModifyGraphWithDelegate(std::move(lazy_delegate_providers_[i]));
|
||||
switch (status) {
|
||||
case kTfLiteOk:
|
||||
TFLITE_LOG(TFLITE_LOG_INFO,
|
||||
"Successfully applied the default TensorFlow Lite "
|
||||
"delegate indexed at %zu.",
|
||||
i);
|
||||
break;
|
||||
case kTfLiteError:
|
||||
TF_LITE_REPORT_ERROR(error_reporter_,
|
||||
"Failed to apply the default TensorFlow Lite "
|
||||
"delegate indexed at %zu.",
|
||||
i);
|
||||
return kTfLiteError;
|
||||
case kTfLiteDelegateError:
|
||||
TF_LITE_REPORT_ERROR(
|
||||
error_reporter_,
|
||||
"Error in applying the default TensorFlow Lite delegate indexed "
|
||||
"at %zu, and all previously applied delegates are reverted.",
|
||||
i);
|
||||
break;
|
||||
case kTfLiteApplicationError:
|
||||
TF_LITE_REPORT_ERROR(error_reporter_,
|
||||
"Ignoring failed application of the default "
|
||||
"TensorFlow Lite delegate indexed at %zu.",
|
||||
i);
|
||||
break;
|
||||
default:
|
||||
TF_LITE_REPORT_ERROR(error_reporter_,
|
||||
"Unknown status (%d) after applying the default "
|
||||
"TensorFlow Lite delegate indexed at %zu.",
|
||||
status, i);
|
||||
return kTfLiteError;
|
||||
}
|
||||
}
|
||||
lazy_delegate_provider_.reset();
|
||||
lazy_delegate_providers_.clear();
|
||||
}
|
||||
|
||||
return primary_subgraph().AllocateTensors();
|
||||
|
@ -653,10 +653,10 @@ class Interpreter {
|
||||
// A map of resources. Owned by interpreter and shared by multiple subgraphs.
|
||||
resource::ResourceMap resources_;
|
||||
|
||||
// Indicating a delegate that the TFLite interpreter will apply by default.
|
||||
// A nullptr value means there's no delegate to be applied by default or the
|
||||
// delegate has been applied and doesn't need to be applied again.
|
||||
TfLiteDelegatePtr lazy_delegate_provider_;
|
||||
// Indicating delegates that the TFLite interpreter will apply by default.
|
||||
// An empty one means there's no delegate to be applied by default or
|
||||
// delegates have been applied and doesn't need to be applied again.
|
||||
std::vector<TfLiteDelegatePtr> lazy_delegate_providers_;
|
||||
};
|
||||
|
||||
} // namespace impl
|
||||
|
@ -29,7 +29,6 @@ limitations under the License.
|
||||
#include "tensorflow/lite/kernels/internal/compatibility.h"
|
||||
#include "tensorflow/lite/schema/schema_generated.h"
|
||||
#include "tensorflow/lite/shared_library.h"
|
||||
#include "tensorflow/lite/tflite_with_xnnpack_optional.h"
|
||||
#include "tensorflow/lite/util.h"
|
||||
#include "tensorflow/lite/version.h"
|
||||
|
||||
@ -675,8 +674,8 @@ TfLiteStatus InterpreterBuilder::operator()(
|
||||
}
|
||||
|
||||
if (num_fp32_tensors_ > 0) {
|
||||
(*interpreter)->lazy_delegate_provider_ =
|
||||
MaybeCreateXNNPACKDelegate(num_threads);
|
||||
(*interpreter)->lazy_delegate_providers_ =
|
||||
op_resolver_.GetDelegates(num_threads);
|
||||
}
|
||||
|
||||
if (ApplyDelegates(interpreter->get(), num_threads) != kTfLiteOk)
|
||||
|
@ -757,6 +757,7 @@ cc_library(
|
||||
deps = [
|
||||
":builtin_op_kernels",
|
||||
"//tensorflow/lite:framework",
|
||||
"//tensorflow/lite:tflite_with_xnnpack_optional",
|
||||
"//tensorflow/lite/c:common",
|
||||
"//tensorflow/lite/schema:schema_fbs",
|
||||
],
|
||||
@ -774,6 +775,7 @@ cc_library(
|
||||
deps = [
|
||||
":builtin_op_kernels",
|
||||
"//tensorflow/lite:framework_lib",
|
||||
"//tensorflow/lite:tflite_with_xnnpack_optional",
|
||||
"//tensorflow/lite/c:common",
|
||||
"//tensorflow/lite/schema:schema_fbs",
|
||||
],
|
||||
@ -791,6 +793,7 @@ cc_library(
|
||||
deps = [
|
||||
":builtin_op_kernels_ruy_and_caching",
|
||||
"//tensorflow/lite:framework_lib",
|
||||
"//tensorflow/lite:tflite_with_xnnpack_optional",
|
||||
"//tensorflow/lite/c:common",
|
||||
"//tensorflow/lite/schema:schema_fbs",
|
||||
],
|
||||
|
@ -18,6 +18,7 @@ limitations under the License.
|
||||
#include "tensorflow/lite/c/common.h"
|
||||
#include "tensorflow/lite/kernels/builtin_op_kernels.h"
|
||||
#include "tensorflow/lite/schema/schema_generated.h"
|
||||
#include "tensorflow/lite/tflite_with_xnnpack_optional.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace ops {
|
||||
@ -303,6 +304,21 @@ BuiltinOpResolver::BuiltinOpResolver() {
|
||||
tflite::ops::custom::Register_DETECTION_POSTPROCESS());
|
||||
}
|
||||
|
||||
OpResolver::TfLiteDelegatePtrVector BuiltinOpResolver::GetDelegates(
|
||||
int num_threads) const {
|
||||
OpResolver::TfLiteDelegatePtrVector delegates;
|
||||
auto xnnpack_delegate = tflite::MaybeCreateXNNPACKDelegate(num_threads);
|
||||
if (xnnpack_delegate != nullptr) {
|
||||
delegates.push_back(std::move(xnnpack_delegate));
|
||||
}
|
||||
return delegates;
|
||||
}
|
||||
|
||||
OpResolver::TfLiteDelegatePtrVector
|
||||
BuiltinOpResolverWithoutDefaultDelegates::GetDelegates(int num_threads) const {
|
||||
return OpResolver::TfLiteDelegatePtrVector();
|
||||
}
|
||||
|
||||
} // namespace builtin
|
||||
} // namespace ops
|
||||
} // namespace tflite
|
||||
|
@ -22,9 +22,22 @@ namespace tflite {
|
||||
namespace ops {
|
||||
namespace builtin {
|
||||
|
||||
// This built-in op resolver provides a list of TfLite delegates that could be
|
||||
// applied by TfLite interpreter by default.
|
||||
class BuiltinOpResolver : public MutableOpResolver {
|
||||
public:
|
||||
BuiltinOpResolver();
|
||||
OpResolver::TfLiteDelegatePtrVector GetDelegates(
|
||||
int num_threads) const override;
|
||||
};
|
||||
|
||||
// TfLite interpreter could apply a TfLite delegate by default. To completely
|
||||
// disable this behavior, one could choose to use the following class
|
||||
// BuiltinOpResolverWithoutDefaultDelegates.
|
||||
class BuiltinOpResolverWithoutDefaultDelegates : public BuiltinOpResolver {
|
||||
public:
|
||||
BuiltinOpResolverWithoutDefaultDelegates() : BuiltinOpResolver() {}
|
||||
OpResolver::TfLiteDelegatePtrVector GetDelegates(int num_threads) const final;
|
||||
};
|
||||
|
||||
} // namespace builtin
|
||||
|
@ -30,7 +30,7 @@ TEST(FloatModel, WithXnnpackDelegate) {
|
||||
|
||||
std::unique_ptr<Interpreter> interpreter;
|
||||
ASSERT_EQ(InterpreterBuilder(*model,
|
||||
ops::builtin::BuiltinOpResolver{})(&interpreter),
|
||||
ops::builtin::BuiltinOpResolver())(&interpreter),
|
||||
kTfLiteOk);
|
||||
ASSERT_TRUE(interpreter);
|
||||
|
||||
@ -48,4 +48,32 @@ TEST(FloatModel, WithXnnpackDelegate) {
|
||||
#endif
|
||||
}
|
||||
|
||||
TEST(FloatModel, DefaultXnnpackDelegateNotAllowed) {
|
||||
// Note: this graph will be fully delegated by the XNNPACK delegate.
|
||||
auto model = FlatBufferModel::BuildFromFile(
|
||||
"tensorflow/lite/testdata/multi_add.bin");
|
||||
ASSERT_TRUE(model);
|
||||
|
||||
std::unique_ptr<Interpreter> interpreter;
|
||||
ASSERT_EQ(
|
||||
InterpreterBuilder(
|
||||
*model, ops::builtin::BuiltinOpResolverWithoutDefaultDelegates())(
|
||||
&interpreter),
|
||||
kTfLiteOk);
|
||||
ASSERT_TRUE(interpreter);
|
||||
|
||||
ASSERT_EQ(interpreter->AllocateTensors(), kTfLiteOk);
|
||||
|
||||
#if TFLITE_HAS_ATTRIBUTE_WEAK || defined(TFLITE_BUILD_WITH_XNNPACK_DELEGATE)
|
||||
// As we don't allow applying xnnpack delegate by default, we will expect the
|
||||
// following:
|
||||
EXPECT_LT(1, interpreter->execution_plan().size());
|
||||
int first_node_id = interpreter->execution_plan()[0];
|
||||
const auto& first_node_reg =
|
||||
interpreter->node_and_registration(first_node_id)->second;
|
||||
const std::string op_name = GetOpNameByRegistration(first_node_reg);
|
||||
EXPECT_EQ("ADD", op_name);
|
||||
#endif
|
||||
}
|
||||
|
||||
} // namespace tflite
|
||||
|
@ -47,7 +47,8 @@ extern "C" {
|
||||
typedef enum TfLiteStatus {
|
||||
kTfLiteOk = 0,
|
||||
kTfLiteError = 1,
|
||||
kTfLiteDelegateError = 2
|
||||
kTfLiteDelegateError = 2,
|
||||
kTfLiteApplicationError = 3
|
||||
} TfLiteStatus;
|
||||
|
||||
// The list of external context types known to TF Lite. This list exists solely
|
||||
|
Loading…
Reference in New Issue
Block a user