Use the global delegate providers to drive the NNAPI delegate test.
PiperOrigin-RevId: 321271726 Change-Id: If50f8ca0bd712fbab9c041bf2db416b00820a251
This commit is contained in:
parent
2256faa6b7
commit
d956c282fb
@ -26,9 +26,13 @@ void InitKernelTest(int* argc, char** argv) {
|
||||
tflite::KernelTestDelegateProviders::Get();
|
||||
delegate_providers->InitFromCmdlineArgs(argc, const_cast<const char**>(argv));
|
||||
|
||||
// TODO(b/160764491): remove the special handling of NNAPI delegate test.
|
||||
tflite::SingleOpModel::SetForceUseNnapi(
|
||||
delegate_providers->ConstParams().Get<bool>("use_nnapi"));
|
||||
if (delegate_providers->ConstParams().Get<bool>("use_nnapi")) {
|
||||
// In Android Q, the NNAPI delegate avoids delegation if the only device
|
||||
// is the reference CPU. However, for testing purposes, we still want
|
||||
// delegation coverage, so force use of this reference path.
|
||||
delegate_providers->MutableParams()->Set<std::string>(
|
||||
"nnapi_accelerator_name", "nnapi-reference");
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
@ -55,26 +55,6 @@ namespace tflite {
|
||||
using ::testing::FloatNear;
|
||||
using ::testing::Matcher;
|
||||
|
||||
namespace {
|
||||
|
||||
// Whether to enable (global) use of NNAPI. Note that this will typically
|
||||
// be set via a command-line flag.
|
||||
static bool force_use_nnapi = false;
|
||||
|
||||
TfLiteDelegate* TestNnApiDelegate() {
|
||||
static TfLiteDelegate* delegate = [] {
|
||||
StatefulNnApiDelegate::Options options;
|
||||
// In Android Q, the NNAPI delegate avoids delegation if the only device
|
||||
// is the reference CPU. However, for testing purposes, we still want
|
||||
// delegation coverage, so force use of this reference path.
|
||||
options.accelerator_name = "nnapi-reference";
|
||||
return new StatefulNnApiDelegate(options);
|
||||
}();
|
||||
return delegate;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
std::vector<Matcher<float>> ArrayFloatNear(const std::vector<float>& values,
|
||||
float max_abs_error) {
|
||||
std::vector<Matcher<float>> matchers;
|
||||
@ -221,26 +201,22 @@ void SingleOpModel::BuildInterpreter(std::vector<std::vector<int>> input_shapes,
|
||||
}
|
||||
|
||||
TfLiteStatus SingleOpModel::ApplyDelegate() {
|
||||
auto* delegate_providers = tflite::KernelTestDelegateProviders::Get();
|
||||
|
||||
if (force_use_nnapi) {
|
||||
delegate_ = TestNnApiDelegate();
|
||||
|
||||
// As we currently have special handling of nnapi delegate in kernel tests,
|
||||
// we turn off the nnapi delegate provider to avoid re-applying it later.
|
||||
// TODO(b/160764491): remove this special handling for NNAPI delegate test.
|
||||
delegate_providers->MutableParams()->Set<bool>("use_nnapi", false);
|
||||
}
|
||||
|
||||
if (delegate_) {
|
||||
TFLITE_LOG(WARN) << "Having a manually-set TfLite delegate, and bypassing "
|
||||
"KernelTestDelegateProviders";
|
||||
return interpreter_->ModifyGraphWithDelegate(delegate_);
|
||||
}
|
||||
|
||||
for (auto& one : delegate_providers->CreateAllDelegates()) {
|
||||
TF_LITE_ENSURE_STATUS(
|
||||
interpreter_->ModifyGraphWithDelegate(std::move(one)));
|
||||
TF_LITE_ENSURE_STATUS(interpreter_->ModifyGraphWithDelegate(delegate_));
|
||||
++num_applied_delegates_;
|
||||
} else {
|
||||
auto* delegate_providers = tflite::KernelTestDelegateProviders::Get();
|
||||
for (auto& one : delegate_providers->CreateAllDelegates()) {
|
||||
// The raw ptr always points to the actual TfLiteDegate object.
|
||||
auto* delegate_raw_ptr = one.get();
|
||||
TF_LITE_ENSURE_STATUS(
|
||||
interpreter_->ModifyGraphWithDelegate(std::move(one)));
|
||||
// Note: 'delegate_' is always set to the last successfully applied one.
|
||||
delegate_ = delegate_raw_ptr;
|
||||
++num_applied_delegates_;
|
||||
}
|
||||
}
|
||||
return kTfLiteOk;
|
||||
}
|
||||
@ -257,13 +233,11 @@ void SingleOpModel::BuildInterpreter(
|
||||
}
|
||||
|
||||
// static
|
||||
void SingleOpModel::SetForceUseNnapi(bool use_nnapi) {
|
||||
force_use_nnapi = use_nnapi;
|
||||
bool SingleOpModel::GetForceUseNnapi() {
|
||||
return tflite::KernelTestDelegateProviders::Get()->ConstParams().Get<bool>(
|
||||
"use_nnapi");
|
||||
}
|
||||
|
||||
// static
|
||||
bool SingleOpModel::GetForceUseNnapi() { return force_use_nnapi; }
|
||||
|
||||
int32_t SingleOpModel::GetTensorSize(int index) const {
|
||||
TfLiteTensor* t = interpreter_->tensor(index);
|
||||
CHECK(t);
|
||||
@ -342,20 +316,27 @@ void SingleOpModel::ExpectOpAcceleratedWithNnapi(const std::string& test_id) {
|
||||
return;
|
||||
}
|
||||
|
||||
// If we have multiple delegates applied, we would skip this check at the
|
||||
// moment.
|
||||
if (num_applied_delegates_ > 1) {
|
||||
TFLITE_LOG(WARN) << "Skipping ExpectOpAcceleratedWithNnapi as "
|
||||
<< num_applied_delegates_
|
||||
<< " delegates have been successfully applied.";
|
||||
return;
|
||||
}
|
||||
TFLITE_LOG(INFO) << "Validating acceleration";
|
||||
const NnApi* nnapi = NnApiImplementation();
|
||||
if (nnapi && nnapi->nnapi_exists &&
|
||||
nnapi->android_sdk_version >=
|
||||
validation_params.value().MinAndroidSdkVersion()) {
|
||||
EXPECT_EQ(
|
||||
CountPartitionsDelegatedTo(interpreter_.get(), TestNnApiDelegate()), 1)
|
||||
EXPECT_EQ(CountPartitionsDelegatedTo(interpreter_.get(), delegate_), 1)
|
||||
<< "Expecting operation to be accelerated but cannot find a partition "
|
||||
"associated to the NNAPI delegate";
|
||||
}
|
||||
}
|
||||
|
||||
void SingleOpModel::ValidateAcceleration() {
|
||||
if (force_use_nnapi) {
|
||||
if (GetForceUseNnapi()) {
|
||||
ExpectOpAcceleratedWithNnapi(GetCurrentTestId());
|
||||
}
|
||||
}
|
||||
|
@ -515,8 +515,7 @@ class SingleOpModel {
|
||||
resolver_ = std::move(resolver);
|
||||
}
|
||||
|
||||
// Enables NNAPI delegate application during interpreter creation.
|
||||
static void SetForceUseNnapi(bool use_nnapi);
|
||||
// Indicate whether the test has the NNAPI delegate applied.
|
||||
static bool GetForceUseNnapi();
|
||||
int CountOpsExecutedByCpuKernel();
|
||||
|
||||
@ -769,6 +768,7 @@ class SingleOpModel {
|
||||
std::vector<flatbuffers::Offset<Tensor>> tensors_;
|
||||
std::vector<flatbuffers::Offset<Buffer>> buffers_;
|
||||
TfLiteDelegate* delegate_ = nullptr;
|
||||
int num_applied_delegates_ = 0;
|
||||
};
|
||||
|
||||
// Populate string tensors.
|
||||
|
Loading…
Reference in New Issue
Block a user