Use BuiltinOpResolver as a way to apply the xnnpack delegate by default in TfLite interpreter. Also, provide another builtin-op resolver to disallow applying the delegate by default.
PiperOrigin-RevId: 327378746 Change-Id: I801790cf2878875fcf23c4781306e8243c8fd0af
This commit is contained in:
parent
7c8b6efc14
commit
4bb7dca446
@ -254,7 +254,6 @@ cc_library(
|
|||||||
":shared_library",
|
":shared_library",
|
||||||
":simple_memory_arena",
|
":simple_memory_arena",
|
||||||
":string",
|
":string",
|
||||||
":tflite_with_xnnpack_optional",
|
|
||||||
":type_to_tflitetype",
|
":type_to_tflitetype",
|
||||||
":util",
|
":util",
|
||||||
":version",
|
":version",
|
||||||
|
@ -47,7 +47,8 @@ extern "C" {
|
|||||||
typedef enum TfLiteStatus {
|
typedef enum TfLiteStatus {
|
||||||
kTfLiteOk = 0,
|
kTfLiteOk = 0,
|
||||||
kTfLiteError = 1,
|
kTfLiteError = 1,
|
||||||
kTfLiteDelegateError = 2
|
kTfLiteDelegateError = 2,
|
||||||
|
kTfLiteApplicationError = 3
|
||||||
} TfLiteStatus;
|
} TfLiteStatus;
|
||||||
|
|
||||||
// The list of external context types known to TF Lite. This list exists solely
|
// The list of external context types known to TF Lite. This list exists solely
|
||||||
|
@ -15,6 +15,8 @@ limitations under the License.
|
|||||||
#ifndef TENSORFLOW_LITE_CORE_API_OP_RESOLVER_H_
|
#ifndef TENSORFLOW_LITE_CORE_API_OP_RESOLVER_H_
|
||||||
#define TENSORFLOW_LITE_CORE_API_OP_RESOLVER_H_
|
#define TENSORFLOW_LITE_CORE_API_OP_RESOLVER_H_
|
||||||
|
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
#include "tensorflow/lite/c/common.h"
|
#include "tensorflow/lite/c/common.h"
|
||||||
#include "tensorflow/lite/core/api/error_reporter.h"
|
#include "tensorflow/lite/core/api/error_reporter.h"
|
||||||
#include "tensorflow/lite/schema/schema_generated.h"
|
#include "tensorflow/lite/schema/schema_generated.h"
|
||||||
@ -32,6 +34,16 @@ class OpResolver {
|
|||||||
/// Finds the op registration of a custom operator by op name.
|
/// Finds the op registration of a custom operator by op name.
|
||||||
virtual const TfLiteRegistration* FindOp(const char* op,
|
virtual const TfLiteRegistration* FindOp(const char* op,
|
||||||
int version) const = 0;
|
int version) const = 0;
|
||||||
|
|
||||||
|
// Returns optional delegates for resolving and handling ops in the flatbuffer
|
||||||
|
// model. This may be used in addition to the standard TfLiteRegistration
|
||||||
|
// lookup for graph resolution.
|
||||||
|
using TfLiteDelegatePtrVector =
|
||||||
|
std::vector<std::unique_ptr<TfLiteDelegate, void (*)(TfLiteDelegate*)>>;
|
||||||
|
virtual TfLiteDelegatePtrVector GetDelegates(int num_threads) const {
|
||||||
|
return TfLiteDelegatePtrVector();
|
||||||
|
}
|
||||||
|
|
||||||
virtual ~OpResolver() {}
|
virtual ~OpResolver() {}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -1414,7 +1414,7 @@ TfLiteStatus Subgraph::ModifyGraphWithDelegate(TfLiteDelegate* delegate) {
|
|||||||
if (state_ == kStateInvokableAndImmutable) {
|
if (state_ == kStateInvokableAndImmutable) {
|
||||||
ReportError(
|
ReportError(
|
||||||
"ModifyGraphWithDelegate is disallowed when graph is immutable.");
|
"ModifyGraphWithDelegate is disallowed when graph is immutable.");
|
||||||
return kTfLiteError;
|
return kTfLiteApplicationError;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!(delegate->flags & kTfLiteDelegateFlagsAllowDynamicTensors)) {
|
if (!(delegate->flags & kTfLiteDelegateFlagsAllowDynamicTensors)) {
|
||||||
|
@ -558,12 +558,15 @@ class Subgraph {
|
|||||||
// be reallocated if the graph was modified (i.e., the caller does *not* need
|
// be reallocated if the graph was modified (i.e., the caller does *not* need
|
||||||
// to explicitly call |AllocateTensors()| again). If tensors were unallocated,
|
// to explicitly call |AllocateTensors()| again). If tensors were unallocated,
|
||||||
// they will remain unallocated after delegate application.
|
// they will remain unallocated after delegate application.
|
||||||
// Returns one of the following three status codes:
|
// Returns one of the following status codes:
|
||||||
// 1. kTfLiteOk: Delegation succeeded
|
// 1. kTfLiteOk: Delegation succeeded
|
||||||
// 2. kTfLiteDelegateError: Delegation failed due to an error in the
|
// 2. kTfLiteDelegateError: Delegation failed due to an error *in the
|
||||||
// delegate. The Subgraph has been restored to its pre-delegation state.
|
// delegate*. The Subgraph has been restored to its pre-delegation state.
|
||||||
// NOTE: This reverts all delegates previously applied to the Subgraph.
|
// NOTE: This reverts all delegates previously applied to the Subgraph.
|
||||||
// 3. kTfLiteError: Unexpected/runtime failure.
|
// 3. kTfLiteApplicationError : Delegation failed to be applied due to the
|
||||||
|
// state that the TfLite runtime is in. However, the Subgraph is still in an
|
||||||
|
// invokable state.
|
||||||
|
// 4. kTfLiteError: Unexpected/runtime failure.
|
||||||
TfLiteStatus ModifyGraphWithDelegate(TfLiteDelegate* delegate);
|
TfLiteStatus ModifyGraphWithDelegate(TfLiteDelegate* delegate);
|
||||||
|
|
||||||
// This un-applies all delegates that have been applied till now, but retains
|
// This un-applies all delegates that have been applied till now, but retains
|
||||||
|
@ -86,9 +86,8 @@ TfLiteQuantization GetQuantizationFromLegacy(
|
|||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
Interpreter::Interpreter(ErrorReporter* error_reporter)
|
Interpreter::Interpreter(ErrorReporter* error_reporter)
|
||||||
: error_reporter_(error_reporter ? error_reporter : DefaultErrorReporter()),
|
: error_reporter_(error_reporter ? error_reporter
|
||||||
lazy_delegate_provider_(
|
: DefaultErrorReporter()) {
|
||||||
TfLiteDelegatePtr(nullptr, [](TfLiteDelegate*) {})) {
|
|
||||||
// TODO(b/128420794): Include the TFLite runtime version in the log.
|
// TODO(b/128420794): Include the TFLite runtime version in the log.
|
||||||
// Prod logging is useful for mobile platforms where scraping console logs is
|
// Prod logging is useful for mobile platforms where scraping console logs is
|
||||||
// critical for debugging.
|
// critical for debugging.
|
||||||
@ -184,21 +183,53 @@ TfLiteStatus Interpreter::SetVariables(std::vector<int> variables) {
|
|||||||
TfLiteStatus Interpreter::AllocateTensors() {
|
TfLiteStatus Interpreter::AllocateTensors() {
|
||||||
// Apply the default delegate that TFLite will enable at this point to allow
|
// Apply the default delegate that TFLite will enable at this point to allow
|
||||||
// other user-level delegates to be applied first.
|
// other user-level delegates to be applied first.
|
||||||
if (lazy_delegate_provider_) {
|
if (!lazy_delegate_providers_.empty()) {
|
||||||
// The execution will fall back to default implementation if the XNNPACK
|
TFLITE_LOG(TFLITE_LOG_INFO,
|
||||||
// delegate fails to be applied. Therefore, we ignore the return status
|
"Applying %zu TensorFlow Lite delegate(s) lazily.",
|
||||||
// here and let it fall through the rest of the code.
|
lazy_delegate_providers_.size());
|
||||||
auto status = ModifyGraphWithDelegate(std::move(lazy_delegate_provider_));
|
// At the moment, XNNPACK delegate is the only one that might be applied
|
||||||
if (status != kTfLiteOk) {
|
// by default, in which case, the execution will fall back to default
|
||||||
|
// implementation if the XNNPACK delegate fails to be applied. Therefore, we
|
||||||
|
// ignore the return status here and let it fall through the rest of the
|
||||||
|
// code.
|
||||||
|
for (size_t i = 0; i < lazy_delegate_providers_.size(); ++i) {
|
||||||
|
auto status =
|
||||||
|
ModifyGraphWithDelegate(std::move(lazy_delegate_providers_[i]));
|
||||||
|
switch (status) {
|
||||||
|
case kTfLiteOk:
|
||||||
|
TFLITE_LOG(TFLITE_LOG_INFO,
|
||||||
|
"Successfully applied the default TensorFlow Lite "
|
||||||
|
"delegate indexed at %zu.",
|
||||||
|
i);
|
||||||
|
break;
|
||||||
|
case kTfLiteError:
|
||||||
|
TF_LITE_REPORT_ERROR(error_reporter_,
|
||||||
|
"Failed to apply the default TensorFlow Lite "
|
||||||
|
"delegate indexed at %zu.",
|
||||||
|
i);
|
||||||
|
return kTfLiteError;
|
||||||
|
case kTfLiteDelegateError:
|
||||||
TF_LITE_REPORT_ERROR(
|
TF_LITE_REPORT_ERROR(
|
||||||
error_reporter_,
|
error_reporter_,
|
||||||
"Ignoring failed application of the default TensorFlow Lite "
|
"Error in applying the default TensorFlow Lite delegate indexed "
|
||||||
"delegate.");
|
"at %zu, and all previously applied delegates are reverted.",
|
||||||
} else {
|
i);
|
||||||
TFLITE_LOG(TFLITE_LOG_INFO,
|
break;
|
||||||
"Successfully applied the default TensorFlow Lite delegate.");
|
case kTfLiteApplicationError:
|
||||||
|
TF_LITE_REPORT_ERROR(error_reporter_,
|
||||||
|
"Ignoring failed application of the default "
|
||||||
|
"TensorFlow Lite delegate indexed at %zu.",
|
||||||
|
i);
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
TF_LITE_REPORT_ERROR(error_reporter_,
|
||||||
|
"Unknown status (%d) after applying the default "
|
||||||
|
"TensorFlow Lite delegate indexed at %zu.",
|
||||||
|
status, i);
|
||||||
|
return kTfLiteError;
|
||||||
}
|
}
|
||||||
lazy_delegate_provider_.reset();
|
}
|
||||||
|
lazy_delegate_providers_.clear();
|
||||||
}
|
}
|
||||||
|
|
||||||
return primary_subgraph().AllocateTensors();
|
return primary_subgraph().AllocateTensors();
|
||||||
|
@ -653,10 +653,10 @@ class Interpreter {
|
|||||||
// A map of resources. Owned by interpreter and shared by multiple subgraphs.
|
// A map of resources. Owned by interpreter and shared by multiple subgraphs.
|
||||||
resource::ResourceMap resources_;
|
resource::ResourceMap resources_;
|
||||||
|
|
||||||
// Indicating a delegate that the TFLite interpreter will apply by default.
|
// Indicating delegates that the TFLite interpreter will apply by default.
|
||||||
// A nullptr value means there's no delegate to be applied by default or the
|
// An empty one means there's no delegate to be applied by default or
|
||||||
// delegate has been applied and doesn't need to be applied again.
|
// delegates have been applied and don't need to be applied again.
|
||||||
TfLiteDelegatePtr lazy_delegate_provider_;
|
std::vector<TfLiteDelegatePtr> lazy_delegate_providers_;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace impl
|
} // namespace impl
|
||||||
|
@ -29,7 +29,6 @@ limitations under the License.
|
|||||||
#include "tensorflow/lite/kernels/internal/compatibility.h"
|
#include "tensorflow/lite/kernels/internal/compatibility.h"
|
||||||
#include "tensorflow/lite/schema/schema_generated.h"
|
#include "tensorflow/lite/schema/schema_generated.h"
|
||||||
#include "tensorflow/lite/shared_library.h"
|
#include "tensorflow/lite/shared_library.h"
|
||||||
#include "tensorflow/lite/tflite_with_xnnpack_optional.h"
|
|
||||||
#include "tensorflow/lite/util.h"
|
#include "tensorflow/lite/util.h"
|
||||||
#include "tensorflow/lite/version.h"
|
#include "tensorflow/lite/version.h"
|
||||||
|
|
||||||
@ -675,8 +674,8 @@ TfLiteStatus InterpreterBuilder::operator()(
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (num_fp32_tensors_ > 0) {
|
if (num_fp32_tensors_ > 0) {
|
||||||
(*interpreter)->lazy_delegate_provider_ =
|
(*interpreter)->lazy_delegate_providers_ =
|
||||||
MaybeCreateXNNPACKDelegate(num_threads);
|
op_resolver_.GetDelegates(num_threads);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (ApplyDelegates(interpreter->get(), num_threads) != kTfLiteOk)
|
if (ApplyDelegates(interpreter->get(), num_threads) != kTfLiteOk)
|
||||||
|
@ -757,6 +757,7 @@ cc_library(
|
|||||||
deps = [
|
deps = [
|
||||||
":builtin_op_kernels",
|
":builtin_op_kernels",
|
||||||
"//tensorflow/lite:framework",
|
"//tensorflow/lite:framework",
|
||||||
|
"//tensorflow/lite:tflite_with_xnnpack_optional",
|
||||||
"//tensorflow/lite/c:common",
|
"//tensorflow/lite/c:common",
|
||||||
"//tensorflow/lite/schema:schema_fbs",
|
"//tensorflow/lite/schema:schema_fbs",
|
||||||
],
|
],
|
||||||
@ -774,6 +775,7 @@ cc_library(
|
|||||||
deps = [
|
deps = [
|
||||||
":builtin_op_kernels",
|
":builtin_op_kernels",
|
||||||
"//tensorflow/lite:framework_lib",
|
"//tensorflow/lite:framework_lib",
|
||||||
|
"//tensorflow/lite:tflite_with_xnnpack_optional",
|
||||||
"//tensorflow/lite/c:common",
|
"//tensorflow/lite/c:common",
|
||||||
"//tensorflow/lite/schema:schema_fbs",
|
"//tensorflow/lite/schema:schema_fbs",
|
||||||
],
|
],
|
||||||
@ -791,6 +793,7 @@ cc_library(
|
|||||||
deps = [
|
deps = [
|
||||||
":builtin_op_kernels_ruy_and_caching",
|
":builtin_op_kernels_ruy_and_caching",
|
||||||
"//tensorflow/lite:framework_lib",
|
"//tensorflow/lite:framework_lib",
|
||||||
|
"//tensorflow/lite:tflite_with_xnnpack_optional",
|
||||||
"//tensorflow/lite/c:common",
|
"//tensorflow/lite/c:common",
|
||||||
"//tensorflow/lite/schema:schema_fbs",
|
"//tensorflow/lite/schema:schema_fbs",
|
||||||
],
|
],
|
||||||
|
@ -18,6 +18,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/lite/c/common.h"
|
#include "tensorflow/lite/c/common.h"
|
||||||
#include "tensorflow/lite/kernels/builtin_op_kernels.h"
|
#include "tensorflow/lite/kernels/builtin_op_kernels.h"
|
||||||
#include "tensorflow/lite/schema/schema_generated.h"
|
#include "tensorflow/lite/schema/schema_generated.h"
|
||||||
|
#include "tensorflow/lite/tflite_with_xnnpack_optional.h"
|
||||||
|
|
||||||
namespace tflite {
|
namespace tflite {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
@ -303,6 +304,21 @@ BuiltinOpResolver::BuiltinOpResolver() {
|
|||||||
tflite::ops::custom::Register_DETECTION_POSTPROCESS());
|
tflite::ops::custom::Register_DETECTION_POSTPROCESS());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
OpResolver::TfLiteDelegatePtrVector BuiltinOpResolver::GetDelegates(
|
||||||
|
int num_threads) const {
|
||||||
|
OpResolver::TfLiteDelegatePtrVector delegates;
|
||||||
|
auto xnnpack_delegate = tflite::MaybeCreateXNNPACKDelegate(num_threads);
|
||||||
|
if (xnnpack_delegate != nullptr) {
|
||||||
|
delegates.push_back(std::move(xnnpack_delegate));
|
||||||
|
}
|
||||||
|
return delegates;
|
||||||
|
}
|
||||||
|
|
||||||
|
OpResolver::TfLiteDelegatePtrVector
|
||||||
|
BuiltinOpResolverWithoutDefaultDelegates::GetDelegates(int num_threads) const {
|
||||||
|
return OpResolver::TfLiteDelegatePtrVector();
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace builtin
|
} // namespace builtin
|
||||||
} // namespace ops
|
} // namespace ops
|
||||||
} // namespace tflite
|
} // namespace tflite
|
||||||
|
@ -22,9 +22,22 @@ namespace tflite {
|
|||||||
namespace ops {
|
namespace ops {
|
||||||
namespace builtin {
|
namespace builtin {
|
||||||
|
|
||||||
|
// This built-in op resolver provides a list of TfLite delegates that could be
|
||||||
|
// applied by TfLite interpreter by default.
|
||||||
class BuiltinOpResolver : public MutableOpResolver {
|
class BuiltinOpResolver : public MutableOpResolver {
|
||||||
public:
|
public:
|
||||||
BuiltinOpResolver();
|
BuiltinOpResolver();
|
||||||
|
OpResolver::TfLiteDelegatePtrVector GetDelegates(
|
||||||
|
int num_threads) const override;
|
||||||
|
};
|
||||||
|
|
||||||
|
// TfLite interpreter could apply a TfLite delegate by default. To completely
|
||||||
|
// disable this behavior, one could choose to use the following class
|
||||||
|
// BuiltinOpResolverWithoutDefaultDelegates.
|
||||||
|
class BuiltinOpResolverWithoutDefaultDelegates : public BuiltinOpResolver {
|
||||||
|
public:
|
||||||
|
BuiltinOpResolverWithoutDefaultDelegates() : BuiltinOpResolver() {}
|
||||||
|
OpResolver::TfLiteDelegatePtrVector GetDelegates(int num_threads) const final;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace builtin
|
} // namespace builtin
|
||||||
|
@ -30,7 +30,7 @@ TEST(FloatModel, WithXnnpackDelegate) {
|
|||||||
|
|
||||||
std::unique_ptr<Interpreter> interpreter;
|
std::unique_ptr<Interpreter> interpreter;
|
||||||
ASSERT_EQ(InterpreterBuilder(*model,
|
ASSERT_EQ(InterpreterBuilder(*model,
|
||||||
ops::builtin::BuiltinOpResolver{})(&interpreter),
|
ops::builtin::BuiltinOpResolver())(&interpreter),
|
||||||
kTfLiteOk);
|
kTfLiteOk);
|
||||||
ASSERT_TRUE(interpreter);
|
ASSERT_TRUE(interpreter);
|
||||||
|
|
||||||
@ -48,4 +48,32 @@ TEST(FloatModel, WithXnnpackDelegate) {
|
|||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(FloatModel, DefaultXnnpackDelegateNotAllowed) {
|
||||||
|
// Note: this graph will be fully delegated by the XNNPACK delegate.
|
||||||
|
auto model = FlatBufferModel::BuildFromFile(
|
||||||
|
"tensorflow/lite/testdata/multi_add.bin");
|
||||||
|
ASSERT_TRUE(model);
|
||||||
|
|
||||||
|
std::unique_ptr<Interpreter> interpreter;
|
||||||
|
ASSERT_EQ(
|
||||||
|
InterpreterBuilder(
|
||||||
|
*model, ops::builtin::BuiltinOpResolverWithoutDefaultDelegates())(
|
||||||
|
&interpreter),
|
||||||
|
kTfLiteOk);
|
||||||
|
ASSERT_TRUE(interpreter);
|
||||||
|
|
||||||
|
ASSERT_EQ(interpreter->AllocateTensors(), kTfLiteOk);
|
||||||
|
|
||||||
|
#if TFLITE_HAS_ATTRIBUTE_WEAK || defined(TFLITE_BUILD_WITH_XNNPACK_DELEGATE)
|
||||||
|
// As we don't allow applying xnnpack delegate by default, we will expect the
|
||||||
|
// following:
|
||||||
|
EXPECT_LT(1, interpreter->execution_plan().size());
|
||||||
|
int first_node_id = interpreter->execution_plan()[0];
|
||||||
|
const auto& first_node_reg =
|
||||||
|
interpreter->node_and_registration(first_node_id)->second;
|
||||||
|
const std::string op_name = GetOpNameByRegistration(first_node_reg);
|
||||||
|
EXPECT_EQ("ADD", op_name);
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace tflite
|
} // namespace tflite
|
||||||
|
@ -47,7 +47,8 @@ extern "C" {
|
|||||||
typedef enum TfLiteStatus {
|
typedef enum TfLiteStatus {
|
||||||
kTfLiteOk = 0,
|
kTfLiteOk = 0,
|
||||||
kTfLiteError = 1,
|
kTfLiteError = 1,
|
||||||
kTfLiteDelegateError = 2
|
kTfLiteDelegateError = 2,
|
||||||
|
kTfLiteApplicationError = 3
|
||||||
} TfLiteStatus;
|
} TfLiteStatus;
|
||||||
|
|
||||||
// The list of external context types known to TF Lite. This list exists solely
|
// The list of external context types known to TF Lite. This list exists solely
|
||||||
|
Loading…
Reference in New Issue
Block a user