Use BuiltinOpResolver as the way to apply the XNNPACK delegate by default in the TfLite interpreter. Also, provide another builtin-op resolver that disallows applying any delegate by default.

PiperOrigin-RevId: 327378746
Change-Id: I801790cf2878875fcf23c4781306e8243c8fd0af
Chao Mei 2020-08-18 23:15:22 -07:00 committed by TensorFlower Gardener
parent 7c8b6efc14
commit 4bb7dca446
13 changed files with 139 additions and 33 deletions
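For reference, a minimal usage sketch of the two resolvers introduced here, mirroring the new test below; the model path and the BuildInterpreter helper are placeholders, not part of this change:

#include <memory>

#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/model.h"

// Builds an interpreter, optionally allowing the resolver-provided default
// delegates (currently only XNNPACK) to be applied lazily in AllocateTensors().
std::unique_ptr<tflite::Interpreter> BuildInterpreter(
    bool allow_default_delegates) {
  auto model = tflite::FlatBufferModel::BuildFromFile("model.tflite");  // placeholder path
  if (!model) return nullptr;
  std::unique_ptr<tflite::Interpreter> interpreter;
  TfLiteStatus status;
  if (allow_default_delegates) {
    // BuiltinOpResolver lets the interpreter apply the XNNPACK delegate by
    // default (when the delegate is linked in).
    status = tflite::InterpreterBuilder(
        *model, tflite::ops::builtin::BuiltinOpResolver())(&interpreter);
  } else {
    // BuiltinOpResolverWithoutDefaultDelegates provides no default delegates.
    status = tflite::InterpreterBuilder(
        *model, tflite::ops::builtin::BuiltinOpResolverWithoutDefaultDelegates())(
        &interpreter);
  }
  if (status != kTfLiteOk || !interpreter) return nullptr;
  // Any resolver-provided default delegates are applied here, lazily.
  if (interpreter->AllocateTensors() != kTfLiteOk) return nullptr;
  return interpreter;
}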

View File

@ -254,7 +254,6 @@ cc_library(
":shared_library",
":simple_memory_arena",
":string",
":tflite_with_xnnpack_optional",
":type_to_tflitetype",
":util",
":version",

View File

@ -47,7 +47,8 @@ extern "C" {
typedef enum TfLiteStatus {
kTfLiteOk = 0,
kTfLiteError = 1,
kTfLiteDelegateError = 2
kTfLiteDelegateError = 2,
kTfLiteApplicationError = 3
} TfLiteStatus;
// The list of external context types known to TF Lite. This list exists solely

View File

@ -15,6 +15,8 @@ limitations under the License.
#ifndef TENSORFLOW_LITE_CORE_API_OP_RESOLVER_H_
#define TENSORFLOW_LITE_CORE_API_OP_RESOLVER_H_
#include <vector>
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/core/api/error_reporter.h"
#include "tensorflow/lite/schema/schema_generated.h"
@ -32,6 +34,16 @@ class OpResolver {
/// Finds the op registration of a custom operator by op name.
virtual const TfLiteRegistration* FindOp(const char* op,
int version) const = 0;
// Returns optional delegates for resolving and handling ops in the flatbuffer
// model. This may be used in addition to the standard TfLiteRegistration
// lookup for graph resolution.
using TfLiteDelegatePtrVector =
std::vector<std::unique_ptr<TfLiteDelegate, void (*)(TfLiteDelegate*)>>;
virtual TfLiteDelegatePtrVector GetDelegates(int num_threads) const {
return TfLiteDelegatePtrVector();
}
virtual ~OpResolver() {}
};
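As a hypothetical illustration of this new hook (MyCreateDelegate and MyDeleteDelegate are placeholders, not part of TfLite), a custom resolver could supply its own lazily applied delegates by overriding GetDelegates():

#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/register.h"

// Placeholders, assumed to be defined elsewhere; not part of TfLite.
TfLiteDelegate* MyCreateDelegate(int num_threads);
void MyDeleteDelegate(TfLiteDelegate* delegate);

class MyOpResolver : public tflite::ops::builtin::BuiltinOpResolver {
 public:
  TfLiteDelegatePtrVector GetDelegates(int num_threads) const override {
    TfLiteDelegatePtrVector delegates;
    TfLiteDelegate* d = MyCreateDelegate(num_threads);
    if (d != nullptr) {
      // Each element is a unique_ptr pairing the raw delegate with its deleter.
      delegates.emplace_back(d, MyDeleteDelegate);
    }
    return delegates;
  }
};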

View File

@ -1414,7 +1414,7 @@ TfLiteStatus Subgraph::ModifyGraphWithDelegate(TfLiteDelegate* delegate) {
if (state_ == kStateInvokableAndImmutable) {
ReportError(
"ModifyGraphWithDelegate is disallowed when graph is immutable.");
return kTfLiteError;
return kTfLiteApplicationError;
}
if (!(delegate->flags & kTfLiteDelegateFlagsAllowDynamicTensors)) {

View File

@ -558,12 +558,15 @@ class Subgraph {
// be reallocated if the graph was modified (i.e., the caller does *not* need
// to explicitly call |AllocateTensors()| again). If tensors were unallocated,
// they will remain unallocated after delegate application.
// Returns one of the following three status codes:
// Returns one of the following status codes:
// 1. kTfLiteOk: Delegation succeeded
// 2. kTfLiteDelegateError: Delegation failed due to an error in the
// delegate. The Subgraph has been restored to its pre-delegation state.
// 2. kTfLiteDelegateError: Delegation failed due to an error *in the
// delegate*. The Subgraph has been restored to its pre-delegation state.
// NOTE: This reverts all delegates previously applied to the Subgraph.
// 3. kTfLiteError: Unexpected/runtime failure.
// 3. kTfLiteApplicationError: Delegation failed to be applied due to the
// state that the TfLite runtime is in. However, the Subgraph is still in an
// invokable state.
// 4. kTfLiteError: Unexpected/runtime failure.
TfLiteStatus ModifyGraphWithDelegate(TfLiteDelegate* delegate);
// This un-applies all delegates that have been applied till now, but retains
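A minimal caller-side sketch (not from this change) of how the expanded status codes could be handled; the Interpreter-level ModifyGraphWithDelegate surfaces the same codes:

#include "tensorflow/lite/interpreter.h"

// Applies `delegate`; treats delegate-side and application-state failures as
// recoverable (the graph stays invokable), anything else as fatal.
TfLiteStatus ApplyDelegateOrFallBack(tflite::Interpreter* interpreter,
                                     TfLiteDelegate* delegate) {
  switch (interpreter->ModifyGraphWithDelegate(delegate)) {
    case kTfLiteOk:
      return kTfLiteOk;  // Delegation succeeded.
    case kTfLiteDelegateError:
      // Error inside the delegate; the graph was restored to its
      // pre-delegation state, so the default CPU kernels still run.
      return kTfLiteOk;
    case kTfLiteApplicationError:
      // The runtime's state disallowed delegation (e.g. the graph is already
      // immutable), but the interpreter remains invokable.
      return kTfLiteOk;
    default:
      return kTfLiteError;  // Unexpected/runtime failure.
  }
}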

View File

@ -86,9 +86,8 @@ TfLiteQuantization GetQuantizationFromLegacy(
} // namespace
Interpreter::Interpreter(ErrorReporter* error_reporter)
: error_reporter_(error_reporter ? error_reporter : DefaultErrorReporter()),
lazy_delegate_provider_(
TfLiteDelegatePtr(nullptr, [](TfLiteDelegate*) {})) {
: error_reporter_(error_reporter ? error_reporter
: DefaultErrorReporter()) {
// TODO(b/128420794): Include the TFLite runtime version in the log.
// Prod logging is useful for mobile platforms where scraping console logs is
// critical for debugging.
@ -184,21 +183,53 @@ TfLiteStatus Interpreter::SetVariables(std::vector<int> variables) {
TfLiteStatus Interpreter::AllocateTensors() {
// Apply the default delegate that TFLite will enable at this point to allow
// other user-level delegates to be applied first.
if (lazy_delegate_provider_) {
// The execution will fall back to default implementation if the XNNPACK
// delegate fails to be applied. Therefore, we ignore the return status
// here and let it fall through the rest of the code.
auto status = ModifyGraphWithDelegate(std::move(lazy_delegate_provider_));
if (status != kTfLiteOk) {
TF_LITE_REPORT_ERROR(
error_reporter_,
"Ignoring failed application of the default TensorFlow Lite "
"delegate.");
} else {
TFLITE_LOG(TFLITE_LOG_INFO,
"Successfully applied the default TensorFlow Lite delegate.");
if (!lazy_delegate_providers_.empty()) {
TFLITE_LOG(TFLITE_LOG_INFO,
"Applying %zu TensorFlow Lite delegate(s) lazily.",
lazy_delegate_providers_.size());
// At the moment, the XNNPACK delegate is the only one that might be applied
// by default, in which case execution will fall back to the default
// implementation if the delegate fails to be applied. Therefore, we ignore
// the return status here and let execution fall through to the rest of the
// code.
for (size_t i = 0; i < lazy_delegate_providers_.size(); ++i) {
auto status =
ModifyGraphWithDelegate(std::move(lazy_delegate_providers_[i]));
switch (status) {
case kTfLiteOk:
TFLITE_LOG(TFLITE_LOG_INFO,
"Successfully applied the default TensorFlow Lite "
"delegate indexed at %zu.",
i);
break;
case kTfLiteError:
TF_LITE_REPORT_ERROR(error_reporter_,
"Failed to apply the default TensorFlow Lite "
"delegate indexed at %zu.",
i);
return kTfLiteError;
case kTfLiteDelegateError:
TF_LITE_REPORT_ERROR(
error_reporter_,
"Error in applying the default TensorFlow Lite delegate indexed "
"at %zu, and all previously applied delegates are reverted.",
i);
break;
case kTfLiteApplicationError:
TF_LITE_REPORT_ERROR(error_reporter_,
"Ignoring failed application of the default "
"TensorFlow Lite delegate indexed at %zu.",
i);
break;
default:
TF_LITE_REPORT_ERROR(error_reporter_,
"Unknown status (%d) after applying the default "
"TensorFlow Lite delegate indexed at %zu.",
status, i);
return kTfLiteError;
}
}
lazy_delegate_provider_.reset();
lazy_delegate_providers_.clear();
}
return primary_subgraph().AllocateTensors();
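In other words, user-level delegates applied before AllocateTensors() take precedence, and the resolver-provided defaults only see what remains. A sketch of that ordering (my_delegate is a placeholder):

// A user-level delegate applied first claims the nodes it can handle; the
// resolver's default delegates are only applied lazily at AllocateTensors().
interpreter->ModifyGraphWithDelegate(my_delegate);  // user delegate first
interpreter->AllocateTensors();  // lazy default delegates applied here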

View File

@ -653,10 +653,10 @@ class Interpreter {
// A map of resources. Owned by interpreter and shared by multiple subgraphs.
resource::ResourceMap resources_;
// Indicating a delegate that the TFLite interpreter will apply by default.
// A nullptr value means there's no delegate to be applied by default or the
// delegate has been applied and doesn't need to be applied again.
TfLiteDelegatePtr lazy_delegate_provider_;
// Indicates delegates that the TFLite interpreter will apply by default.
// An empty vector means there are no delegates to be applied by default, or
// the delegates have already been applied and don't need to be applied again.
std::vector<TfLiteDelegatePtr> lazy_delegate_providers_;
};
} // namespace impl

View File

@ -29,7 +29,6 @@ limitations under the License.
#include "tensorflow/lite/kernels/internal/compatibility.h"
#include "tensorflow/lite/schema/schema_generated.h"
#include "tensorflow/lite/shared_library.h"
#include "tensorflow/lite/tflite_with_xnnpack_optional.h"
#include "tensorflow/lite/util.h"
#include "tensorflow/lite/version.h"
@ -675,8 +674,8 @@ TfLiteStatus InterpreterBuilder::operator()(
}
if (num_fp32_tensors_ > 0) {
(*interpreter)->lazy_delegate_provider_ =
MaybeCreateXNNPACKDelegate(num_threads);
(*interpreter)->lazy_delegate_providers_ =
op_resolver_.GetDelegates(num_threads);
}
if (ApplyDelegates(interpreter->get(), num_threads) != kTfLiteOk)

View File

@ -757,6 +757,7 @@ cc_library(
deps = [
":builtin_op_kernels",
"//tensorflow/lite:framework",
"//tensorflow/lite:tflite_with_xnnpack_optional",
"//tensorflow/lite/c:common",
"//tensorflow/lite/schema:schema_fbs",
],
@ -774,6 +775,7 @@ cc_library(
deps = [
":builtin_op_kernels",
"//tensorflow/lite:framework_lib",
"//tensorflow/lite:tflite_with_xnnpack_optional",
"//tensorflow/lite/c:common",
"//tensorflow/lite/schema:schema_fbs",
],
@ -791,6 +793,7 @@ cc_library(
deps = [
":builtin_op_kernels_ruy_and_caching",
"//tensorflow/lite:framework_lib",
"//tensorflow/lite:tflite_with_xnnpack_optional",
"//tensorflow/lite/c:common",
"//tensorflow/lite/schema:schema_fbs",
],

View File

@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/builtin_op_kernels.h"
#include "tensorflow/lite/schema/schema_generated.h"
#include "tensorflow/lite/tflite_with_xnnpack_optional.h"
namespace tflite {
namespace ops {
@ -303,6 +304,21 @@ BuiltinOpResolver::BuiltinOpResolver() {
tflite::ops::custom::Register_DETECTION_POSTPROCESS());
}
OpResolver::TfLiteDelegatePtrVector BuiltinOpResolver::GetDelegates(
int num_threads) const {
OpResolver::TfLiteDelegatePtrVector delegates;
auto xnnpack_delegate = tflite::MaybeCreateXNNPACKDelegate(num_threads);
if (xnnpack_delegate != nullptr) {
delegates.push_back(std::move(xnnpack_delegate));
}
return delegates;
}
OpResolver::TfLiteDelegatePtrVector
BuiltinOpResolverWithoutDefaultDelegates::GetDelegates(int num_threads) const {
return OpResolver::TfLiteDelegatePtrVector();
}
} // namespace builtin
} // namespace ops
} // namespace tflite

View File

@ -22,9 +22,22 @@ namespace tflite {
namespace ops {
namespace builtin {
// This built-in op resolver provides a list of TfLite delegates that could be
// applied by the TfLite interpreter by default.
class BuiltinOpResolver : public MutableOpResolver {
public:
BuiltinOpResolver();
OpResolver::TfLiteDelegatePtrVector GetDelegates(
int num_threads) const override;
};
// The TfLite interpreter could apply a TfLite delegate by default. To
// completely disable this behavior, use the following class,
// BuiltinOpResolverWithoutDefaultDelegates.
class BuiltinOpResolverWithoutDefaultDelegates : public BuiltinOpResolver {
public:
BuiltinOpResolverWithoutDefaultDelegates() : BuiltinOpResolver() {}
OpResolver::TfLiteDelegatePtrVector GetDelegates(int num_threads) const final;
};
} // namespace builtin

View File

@ -30,7 +30,7 @@ TEST(FloatModel, WithXnnpackDelegate) {
std::unique_ptr<Interpreter> interpreter;
ASSERT_EQ(InterpreterBuilder(*model,
ops::builtin::BuiltinOpResolver{})(&interpreter),
ops::builtin::BuiltinOpResolver())(&interpreter),
kTfLiteOk);
ASSERT_TRUE(interpreter);
@ -48,4 +48,32 @@ TEST(FloatModel, WithXnnpackDelegate) {
#endif
}
TEST(FloatModel, DefaultXnnpackDelegateNotAllowed) {
// Note: this graph would be fully delegated by the XNNPACK delegate if it
// were applied.
auto model = FlatBufferModel::BuildFromFile(
"tensorflow/lite/testdata/multi_add.bin");
ASSERT_TRUE(model);
std::unique_ptr<Interpreter> interpreter;
ASSERT_EQ(
InterpreterBuilder(
*model, ops::builtin::BuiltinOpResolverWithoutDefaultDelegates())(
&interpreter),
kTfLiteOk);
ASSERT_TRUE(interpreter);
ASSERT_EQ(interpreter->AllocateTensors(), kTfLiteOk);
#if TFLITE_HAS_ATTRIBUTE_WEAK || defined(TFLITE_BUILD_WITH_XNNPACK_DELEGATE)
// As we don't allow applying the XNNPACK delegate by default, we expect the
// following:
EXPECT_LT(1, interpreter->execution_plan().size());
int first_node_id = interpreter->execution_plan()[0];
const auto& first_node_reg =
interpreter->node_and_registration(first_node_id)->second;
const std::string op_name = GetOpNameByRegistration(first_node_reg);
EXPECT_EQ("ADD", op_name);
#endif
}
} // namespace tflite

View File

@ -47,7 +47,8 @@ extern "C" {
typedef enum TfLiteStatus {
kTfLiteOk = 0,
kTfLiteError = 1,
kTfLiteDelegateError = 2
kTfLiteDelegateError = 2,
kTfLiteApplicationError = 3
} TfLiteStatus;
// The list of external context types known to TF Lite. This list exists solely