Use the global delegate providers to drive the NNAPI delegate test.

PiperOrigin-RevId: 321271726
Change-Id: If50f8ca0bd712fbab9c041bf2db416b00820a251
This commit is contained in:
Chao Mei 2020-07-14 17:47:34 -07:00 committed by TensorFlower Gardener
parent 2256faa6b7
commit d956c282fb
3 changed files with 35 additions and 50 deletions

View File

@ -26,9 +26,13 @@ void InitKernelTest(int* argc, char** argv) {
tflite::KernelTestDelegateProviders::Get();
delegate_providers->InitFromCmdlineArgs(argc, const_cast<const char**>(argv));
// TODO(b/160764491): remove the special handling of NNAPI delegate test.
tflite::SingleOpModel::SetForceUseNnapi(
delegate_providers->ConstParams().Get<bool>("use_nnapi"));
if (delegate_providers->ConstParams().Get<bool>("use_nnapi")) {
// In Android Q, the NNAPI delegate avoids delegation if the only device
// is the reference CPU. However, for testing purposes, we still want
// delegation coverage, so force use of this reference path.
delegate_providers->MutableParams()->Set<std::string>(
"nnapi_accelerator_name", "nnapi-reference");
}
}
} // namespace

View File

@ -55,26 +55,6 @@ namespace tflite {
using ::testing::FloatNear;
using ::testing::Matcher;
namespace {
// Whether to enable (global) use of NNAPI. Note that this will typically
// be set via a command-line flag.
static bool force_use_nnapi = false;
TfLiteDelegate* TestNnApiDelegate() {
static TfLiteDelegate* delegate = [] {
StatefulNnApiDelegate::Options options;
// In Android Q, the NNAPI delegate avoids delegation if the only device
// is the reference CPU. However, for testing purposes, we still want
// delegation coverage, so force use of this reference path.
options.accelerator_name = "nnapi-reference";
return new StatefulNnApiDelegate(options);
}();
return delegate;
}
} // namespace
std::vector<Matcher<float>> ArrayFloatNear(const std::vector<float>& values,
float max_abs_error) {
std::vector<Matcher<float>> matchers;
@ -221,26 +201,22 @@ void SingleOpModel::BuildInterpreter(std::vector<std::vector<int>> input_shapes,
}
TfLiteStatus SingleOpModel::ApplyDelegate() {
auto* delegate_providers = tflite::KernelTestDelegateProviders::Get();
if (force_use_nnapi) {
delegate_ = TestNnApiDelegate();
// As we currently have special handling of nnapi delegate in kernel tests,
// we turn off the nnapi delegate provider to avoid re-applying it later.
// TODO(b/160764491): remove this special handling for NNAPI delegate test.
delegate_providers->MutableParams()->Set<bool>("use_nnapi", false);
}
if (delegate_) {
TFLITE_LOG(WARN) << "Having a manually-set TfLite delegate, and bypassing "
"KernelTestDelegateProviders";
return interpreter_->ModifyGraphWithDelegate(delegate_);
}
for (auto& one : delegate_providers->CreateAllDelegates()) {
TF_LITE_ENSURE_STATUS(
interpreter_->ModifyGraphWithDelegate(std::move(one)));
TF_LITE_ENSURE_STATUS(interpreter_->ModifyGraphWithDelegate(delegate_));
++num_applied_delegates_;
} else {
auto* delegate_providers = tflite::KernelTestDelegateProviders::Get();
for (auto& one : delegate_providers->CreateAllDelegates()) {
// The raw ptr always points to the actual TfLiteDegate object.
auto* delegate_raw_ptr = one.get();
TF_LITE_ENSURE_STATUS(
interpreter_->ModifyGraphWithDelegate(std::move(one)));
// Note: 'delegate_' is always set to the last successfully applied one.
delegate_ = delegate_raw_ptr;
++num_applied_delegates_;
}
}
return kTfLiteOk;
}
@ -257,13 +233,11 @@ void SingleOpModel::BuildInterpreter(
}
// static
void SingleOpModel::SetForceUseNnapi(bool use_nnapi) {
force_use_nnapi = use_nnapi;
bool SingleOpModel::GetForceUseNnapi() {
return tflite::KernelTestDelegateProviders::Get()->ConstParams().Get<bool>(
"use_nnapi");
}
// static
bool SingleOpModel::GetForceUseNnapi() { return force_use_nnapi; }
int32_t SingleOpModel::GetTensorSize(int index) const {
TfLiteTensor* t = interpreter_->tensor(index);
CHECK(t);
@ -342,20 +316,27 @@ void SingleOpModel::ExpectOpAcceleratedWithNnapi(const std::string& test_id) {
return;
}
// If we have multiple delegates applied, we would skip this check at the
// moment.
if (num_applied_delegates_ > 1) {
TFLITE_LOG(WARN) << "Skipping ExpectOpAcceleratedWithNnapi as "
<< num_applied_delegates_
<< " delegates have been successfully applied.";
return;
}
TFLITE_LOG(INFO) << "Validating acceleration";
const NnApi* nnapi = NnApiImplementation();
if (nnapi && nnapi->nnapi_exists &&
nnapi->android_sdk_version >=
validation_params.value().MinAndroidSdkVersion()) {
EXPECT_EQ(
CountPartitionsDelegatedTo(interpreter_.get(), TestNnApiDelegate()), 1)
EXPECT_EQ(CountPartitionsDelegatedTo(interpreter_.get(), delegate_), 1)
<< "Expecting operation to be accelerated but cannot find a partition "
"associated to the NNAPI delegate";
}
}
void SingleOpModel::ValidateAcceleration() {
if (force_use_nnapi) {
if (GetForceUseNnapi()) {
ExpectOpAcceleratedWithNnapi(GetCurrentTestId());
}
}

View File

@ -515,8 +515,7 @@ class SingleOpModel {
resolver_ = std::move(resolver);
}
// Enables NNAPI delegate application during interpreter creation.
static void SetForceUseNnapi(bool use_nnapi);
// Indicate whether the test has the NNAPI delegate applied.
static bool GetForceUseNnapi();
int CountOpsExecutedByCpuKernel();
@ -769,6 +768,7 @@ class SingleOpModel {
std::vector<flatbuffers::Offset<Tensor>> tensors_;
std::vector<flatbuffers::Offset<Buffer>> buffers_;
TfLiteDelegate* delegate_ = nullptr;
int num_applied_delegates_ = 0;
};
// Populate string tensors.