STT-tensorflow/tensorflow/lite/tools/delegates/nnapi_delegate_provider.cc
Robert David f1e137db12 Make lite/tools/evaluation:utils depend on the NNAPI delegate only when building for Android.
PiperOrigin-RevId: 313615943
Change-Id: Idf3b05cfea63c9578c726e5ed7b5afacd9e1a495
2020-05-28 10:42:55 -07:00

172 lines
7.0 KiB
C++

/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <string>
#include "tensorflow/lite/tools/delegates/delegate_provider.h"
#include "tensorflow/lite/tools/evaluation/utils.h"
#if defined(__ANDROID__)
#include "tensorflow/lite/nnapi/nnapi_util.h"
#endif
namespace tflite {
namespace tools {
// DelegateProvider that wires the NNAPI delegate into the TFLite tools.
//
// All NNAPI parameters and flags exist only in Android builds
// (`__ANDROID__`); elsewhere this provider registers no parameters, exposes
// no flags, and its CreateTfLiteDelegate() returns a null delegate.
class NnapiDelegateProvider : public DelegateProvider {
 public:
  NnapiDelegateProvider();

  // Flags bound to the NNAPI entries of `params` (empty off Android).
  std::vector<Flag> CreateFlags(ToolParams* params) const final;
  // Logs the NNAPI settings found in `params` (no-op off Android).
  void LogParams(const ToolParams& params) const final;
  // Builds the NNAPI delegate requested by `params`, or a null delegate.
  TfLiteDelegatePtr CreateTfLiteDelegate(const ToolParams& params) const final;
  std::string GetName() const final { return "NNAPI"; }
};

// Registers the default of every NNAPI parameter (Android builds only):
// NNAPI off, no explicit execution preference, no accelerator name, NNAPI CPU
// device not disabled, fp16 relaxation off.
NnapiDelegateProvider::NnapiDelegateProvider() {
#if defined(__ANDROID__)
  default_params_.AddParam("use_nnapi", ToolParam::Create<bool>(false));
  default_params_.AddParam("nnapi_execution_preference",
                           ToolParam::Create<std::string>(""));
  default_params_.AddParam("nnapi_accelerator_name",
                           ToolParam::Create<std::string>(""));
  default_params_.AddParam("disable_nnapi_cpu", ToolParam::Create<bool>(false));
  default_params_.AddParam("nnapi_allow_fp16", ToolParam::Create<bool>(false));
#endif
}

REGISTER_DELEGATE_PROVIDER(NnapiDelegateProvider);
// Creates one command-line flag per NNAPI parameter, each bound to the
// same-named entry in `params`. Non-Android builds get an empty flag list.
std::vector<Flag> NnapiDelegateProvider::CreateFlags(ToolParams* params) const {
#if defined(__ANDROID__)
  return {
      CreateFlag<bool>("use_nnapi", params, "use nnapi delegate api"),
      CreateFlag<std::string>("nnapi_execution_preference", params,
                              "execution preference for nnapi delegate. Should "
                              "be one of the following: fast_single_answer, "
                              "sustained_speed, low_power, undefined"),
      CreateFlag<std::string>(
          "nnapi_accelerator_name", params,
          "the name of the nnapi accelerator to use (requires Android Q+)"),
      CreateFlag<bool>("disable_nnapi_cpu", params,
                       "Disable the NNAPI CPU device"),
      CreateFlag<bool>("nnapi_allow_fp16", params,
                       "Allow fp32 computation to be run in fp16"),
  };
#else
  return {};
#endif
}
// Logs the NNAPI-related tool parameters held in `params`.
//
// On Android builds "Use nnapi" is always logged. The remaining settings are
// logged only when NNAPI is enabled, and the two boolean options only when
// they are turned on. The accelerator line also lists the names returned by
// nnapi::GetStringDeviceNamesList() when that string is non-empty. Nothing is
// logged on non-Android builds.
void NnapiDelegateProvider::LogParams(const ToolParams& params) const {
#if defined(__ANDROID__)
  const bool use_nnapi = params.Get<bool>("use_nnapi");
  TFLITE_LOG(INFO) << "Use nnapi : [" << use_nnapi << "]";
  if (!use_nnapi) {
    return;
  }

  const std::string preference =
      params.Get<std::string>("nnapi_execution_preference");
  if (!preference.empty()) {
    TFLITE_LOG(INFO) << "nnapi execution preference: [" << preference << "]";
  }

  std::string accelerator_line = "nnapi accelerator name: [";
  accelerator_line += params.Get<std::string>("nnapi_accelerator_name");
  accelerator_line += "]";
  // Append the available device names, but only when some are reported.
  const std::string available_devices = nnapi::GetStringDeviceNamesList();
  if (!available_devices.empty()) {
    accelerator_line += " (Available: ";
    accelerator_line += available_devices;
    accelerator_line += ")";
  }
  TFLITE_LOG(INFO) << accelerator_line;

  const bool disable_cpu = params.Get<bool>("disable_nnapi_cpu");
  if (disable_cpu) {
    TFLITE_LOG(INFO) << "disable_nnapi_cpu: [" << disable_cpu << "]";
  }
  const bool allow_fp16 = params.Get<bool>("nnapi_allow_fp16");
  if (allow_fp16) {
    TFLITE_LOG(INFO) << "Allow fp16 in NNAPI: [" << allow_fp16 << "]";
  }
#endif
}
#if defined(__ANDROID__)
namespace {

// Converts a --nnapi_execution_preference string into the corresponding
// StatefulNnApiDelegate::Options::ExecutionPreference value. Accepted values
// are "low_power", "sustained_speed", "fast_single_answer" and "undefined";
// any other value is reported with a warning and mapped to kUndefined.
StatefulNnApiDelegate::Options::ExecutionPreference
ParseNnapiExecutionPreference(const std::string& preference) {
  using Options = StatefulNnApiDelegate::Options;
  if (preference == "low_power") return Options::kLowPower;
  if (preference == "sustained_speed") return Options::kSustainedSpeed;
  if (preference == "fast_single_answer") return Options::kFastSingleAnswer;
  if (preference != "undefined") {
    TFLITE_LOG(WARN) << "The provided value (" << preference
                     << ") is not a valid nnapi execution preference.";
  }
  return Options::kUndefined;
}

}  // namespace
#endif  // defined(__ANDROID__)

// Creates the NNAPI delegate configured by `params`.
//
// Returns a null delegate (with a no-op deleter) in three cases: when
// `use_nnapi` is false, when evaluation::CreateNNAPIDelegate() returns null,
// or on non-Android builds. When NNAPI is disabled but NNAPI-only options were
// supplied, one warning is logged for each option that will be ignored.
TfLiteDelegatePtr NnapiDelegateProvider::CreateTfLiteDelegate(
    const ToolParams& params) const {
  TfLiteDelegatePtr delegate(nullptr, [](TfLiteDelegate*) {});
#if defined(__ANDROID__)
  const std::string accelerator_name =
      params.Get<std::string>("nnapi_accelerator_name");
  const std::string execution_preference =
      params.Get<std::string>("nnapi_execution_preference");

  if (params.Get<bool>("use_nnapi")) {
    StatefulNnApiDelegate::Options options;
    if (!accelerator_name.empty()) {
      // `options` receives a pointer into `accelerator_name`, so the string
      // must remain alive at least until CreateNNAPIDelegate() is called below.
      options.accelerator_name = accelerator_name.c_str();
    } else if (params.Get<bool>("disable_nnapi_cpu")) {
      // disable_nnapi_cpu is only honored when no explicit accelerator is
      // named; an explicit accelerator takes precedence.
      options.disallow_nnapi_cpu = true;
    }
    if (params.Get<bool>("nnapi_allow_fp16")) {
      options.allow_fp16 = true;
    }
    // Only set the execution preference if the user explicitly passed one;
    // otherwise leave whatever default the delegate options already carry.
    if (!execution_preference.empty()) {
      options.execution_preference =
          ParseNnapiExecutionPreference(execution_preference);
    }
    // NOTE(review): "max_delegated_partitions" is not registered by this
    // provider; presumably it is a tool-wide parameter registered elsewhere.
    const int max_delegated_partitions =
        params.Get<int>("max_delegated_partitions");
    if (max_delegated_partitions > 0) {
      options.max_number_delegated_partitions = max_delegated_partitions;
    }
    delegate = evaluation::CreateNNAPIDelegate(options);
    if (!delegate.get()) {
      TFLITE_LOG(WARN) << "NNAPI acceleration is unsupported on this platform.";
    }
  } else {
    // NNAPI is disabled. Report every NNAPI-only option that will be ignored,
    // not just the first one found.
    if (!accelerator_name.empty()) {
      TFLITE_LOG(WARN)
          << "`--use_nnapi=true` must be set for the provided NNAPI accelerator ("
          << accelerator_name << ") to be used.";
    }
    if (!execution_preference.empty()) {
      TFLITE_LOG(WARN) << "`--use_nnapi=true` must be set for the provided NNAPI "
                          "execution preference ("
                       << execution_preference << ") to be used.";
    }
  }
#endif
  return delegate;
}
} // namespace tools
} // namespace tflite