Allow specififying the nnapi accelerator in benchmark_model
Add a `--nnapi_accelerator_name` flag to benchmark_model, allowing explicit accelerator control on Android Q+. PiperOrigin-RevId: 258626781
This commit is contained in:
parent
75e3f9288e
commit
d88b16b545
@ -233,6 +233,8 @@ BenchmarkParams BenchmarkTfLiteModel::DefaultParams() {
|
||||
default_params.AddParam("use_nnapi", BenchmarkParam::Create<bool>(false));
|
||||
default_params.AddParam("use_legacy_nnapi",
|
||||
BenchmarkParam::Create<bool>(false));
|
||||
default_params.AddParam("nnapi_accelerator_name",
|
||||
BenchmarkParam::Create<std::string>(""));
|
||||
default_params.AddParam("use_gpu", BenchmarkParam::Create<bool>(false));
|
||||
default_params.AddParam("allow_fp16", BenchmarkParam::Create<bool>(false));
|
||||
default_params.AddParam(
|
||||
@ -271,6 +273,9 @@ std::vector<Flag> BenchmarkTfLiteModel::GetFlags() {
|
||||
"input layer shape"),
|
||||
CreateFlag<bool>("use_nnapi", ¶ms_, "use nnapi delegate api"),
|
||||
CreateFlag<bool>("use_legacy_nnapi", ¶ms_, "use legacy nnapi api"),
|
||||
CreateFlag<std::string>(
|
||||
"nnapi_accelerator_name", ¶ms_,
|
||||
"the name of the nnapi accelerator to use (requires Android Q+)"),
|
||||
CreateFlag<bool>("use_gpu", ¶ms_, "use gpu"),
|
||||
CreateFlag<bool>("allow_fp16", ¶ms_, "allow fp16"),
|
||||
CreateFlag<bool>("enable_op_profiling", ¶ms_, "enable op profiling"),
|
||||
@ -291,6 +296,10 @@ void BenchmarkTfLiteModel::LogParams() {
|
||||
TFLITE_LOG(INFO) << "Use nnapi : [" << params_.Get<bool>("use_nnapi") << "]";
|
||||
TFLITE_LOG(INFO) << "Use legacy nnapi : ["
|
||||
<< params_.Get<bool>("use_legacy_nnapi") << "]";
|
||||
if (params_.HasParam("nnapi_accelerator_name")) {
|
||||
TFLITE_LOG(INFO) << "nnapi accelerator name: ["
|
||||
<< params_.Get<string>("nnapi_accelerator_name") << "]";
|
||||
}
|
||||
TFLITE_LOG(INFO) << "Use gpu : [" << params_.Get<bool>("use_gpu") << "]";
|
||||
TFLITE_LOG(INFO) << "Allow fp16 : [" << params_.Get<bool>("allow_fp16")
|
||||
<< "]";
|
||||
@ -506,12 +515,24 @@ BenchmarkTfLiteModel::TfLiteDelegatePtrMap BenchmarkTfLiteModel::GetDelegates()
|
||||
}
|
||||
}
|
||||
if (params_.Get<bool>("use_nnapi")) {
|
||||
Interpreter::TfLiteDelegatePtr delegate = evaluation::CreateNNAPIDelegate();
|
||||
StatefulNnApiDelegate::Options options;
|
||||
std::string accelerator_name;
|
||||
if (params_.HasParam("nnapi_accelerator_name")) {
|
||||
accelerator_name = params_.Get<std::string>("nnapi_accelerator_name");
|
||||
options.accelerator_name = accelerator_name.c_str();
|
||||
}
|
||||
Interpreter::TfLiteDelegatePtr delegate =
|
||||
evaluation::CreateNNAPIDelegate(options);
|
||||
if (!delegate) {
|
||||
TFLITE_LOG(WARN) << "NNAPI acceleration is unsupported on this platform.";
|
||||
} else {
|
||||
delegates.emplace("NNAPI", std::move(delegate));
|
||||
}
|
||||
} else if (params_.HasParam("nnapi_accelerator_name")) {
|
||||
TFLITE_LOG(WARN)
|
||||
<< "`--use_nnapi=true` must be set for the provided NNAPI accelerator ("
|
||||
<< params_.Get<std::string>("nnapi_accelerator_name")
|
||||
<< ") to be used.";
|
||||
}
|
||||
return delegates;
|
||||
}
|
||||
|
@ -23,8 +23,6 @@ limitations under the License.
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
#include "tensorflow/lite/delegates/nnapi/nnapi_delegate.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace evaluation {
|
||||
|
||||
@ -86,6 +84,18 @@ Interpreter::TfLiteDelegatePtr CreateNNAPIDelegate() {
|
||||
#endif // defined(__ANDROID__)
|
||||
}
|
||||
|
||||
Interpreter::TfLiteDelegatePtr CreateNNAPIDelegate(
|
||||
StatefulNnApiDelegate::Options options) {
|
||||
#if defined(__ANDROID__)
|
||||
return Interpreter::TfLiteDelegatePtr(
|
||||
new StatefulNnApiDelegate(options), [](TfLiteDelegate* delegate) {
|
||||
delete reinterpret_cast<StatefulNnApiDelegate*>(delegate);
|
||||
});
|
||||
#else
|
||||
return Interpreter::TfLiteDelegatePtr(nullptr, [](TfLiteDelegate*) {});
|
||||
#endif // defined(__ANDROID__)
|
||||
}
|
||||
|
||||
#if defined(__ANDROID__)
|
||||
Interpreter::TfLiteDelegatePtr CreateGPUDelegate(
|
||||
tflite::FlatBufferModel* model, TfLiteGpuDelegateOptions* options) {
|
||||
|
@ -24,6 +24,7 @@ limitations under the License.
|
||||
#endif
|
||||
|
||||
#include "tensorflow/lite/context.h"
|
||||
#include "tensorflow/lite/delegates/nnapi/nnapi_delegate.h"
|
||||
#include "tensorflow/lite/model.h"
|
||||
|
||||
namespace tflite {
|
||||
@ -38,6 +39,9 @@ TfLiteStatus GetSortedFileNames(const std::string& directory,
|
||||
|
||||
Interpreter::TfLiteDelegatePtr CreateNNAPIDelegate();
|
||||
|
||||
Interpreter::TfLiteDelegatePtr CreateNNAPIDelegate(
|
||||
StatefulNnApiDelegate::Options options);
|
||||
|
||||
Interpreter::TfLiteDelegatePtr CreateGPUDelegate(FlatBufferModel* model);
|
||||
#if defined(__ANDROID__)
|
||||
Interpreter::TfLiteDelegatePtr CreateGPUDelegate(
|
||||
|
Loading…
Reference in New Issue
Block a user