diff --git a/tensorflow/lite/tools/benchmark/benchmark_tflite_model.cc b/tensorflow/lite/tools/benchmark/benchmark_tflite_model.cc
index 3d4f8147912..e527796664f 100644
--- a/tensorflow/lite/tools/benchmark/benchmark_tflite_model.cc
+++ b/tensorflow/lite/tools/benchmark/benchmark_tflite_model.cc
@@ -233,6 +233,8 @@ BenchmarkParams BenchmarkTfLiteModel::DefaultParams() {
   default_params.AddParam("use_nnapi", BenchmarkParam::Create<bool>(false));
   default_params.AddParam("use_legacy_nnapi",
                           BenchmarkParam::Create<bool>(false));
+  default_params.AddParam("nnapi_accelerator_name",
+                          BenchmarkParam::Create<std::string>(""));
   default_params.AddParam("use_gpu", BenchmarkParam::Create<bool>(false));
   default_params.AddParam("allow_fp16", BenchmarkParam::Create<bool>(false));
   default_params.AddParam(
@@ -271,6 +273,9 @@ std::vector<Flag> BenchmarkTfLiteModel::GetFlags() {
                         "input layer shape"),
       CreateFlag<bool>("use_nnapi", &params_, "use nnapi delegate api"),
       CreateFlag<bool>("use_legacy_nnapi", &params_, "use legacy nnapi api"),
+      CreateFlag<std::string>(
+          "nnapi_accelerator_name", &params_,
+          "the name of the nnapi accelerator to use (requires Android Q+)"),
       CreateFlag<bool>("use_gpu", &params_, "use gpu"),
       CreateFlag<bool>("allow_fp16", &params_, "allow fp16"),
       CreateFlag<bool>("enable_op_profiling", &params_, "enable op profiling"),
@@ -291,6 +296,10 @@ void BenchmarkTfLiteModel::LogParams() {
   TFLITE_LOG(INFO) << "Use nnapi : [" << params_.Get<bool>("use_nnapi") << "]";
   TFLITE_LOG(INFO) << "Use legacy nnapi : ["
                    << params_.Get<bool>("use_legacy_nnapi") << "]";
+  if (params_.HasParam("nnapi_accelerator_name")) {
+    TFLITE_LOG(INFO) << "nnapi accelerator name: ["
+                     << params_.Get<std::string>("nnapi_accelerator_name") << "]";
+  }
   TFLITE_LOG(INFO) << "Use gpu : [" << params_.Get<bool>("use_gpu") << "]";
   TFLITE_LOG(INFO) << "Allow fp16 : [" << params_.Get<bool>("allow_fp16")
                    << "]";
@@ -506,12 +515,24 @@ BenchmarkTfLiteModel::TfLiteDelegatePtrMap BenchmarkTfLiteModel::GetDelegates()
     }
   }
   if (params_.Get<bool>("use_nnapi")) {
-    Interpreter::TfLiteDelegatePtr delegate = evaluation::CreateNNAPIDelegate();
+    StatefulNnApiDelegate::Options options;
+    std::string accelerator_name;
+    if (params_.HasParam("nnapi_accelerator_name")) {
+      accelerator_name = params_.Get<std::string>("nnapi_accelerator_name");
+      options.accelerator_name = accelerator_name.c_str();
+    }
+    Interpreter::TfLiteDelegatePtr delegate =
+        evaluation::CreateNNAPIDelegate(options);
     if (!delegate) {
       TFLITE_LOG(WARN) << "NNAPI acceleration is unsupported on this platform.";
     } else {
       delegates.emplace("NNAPI", std::move(delegate));
     }
+  } else if (params_.HasParam("nnapi_accelerator_name")) {
+    TFLITE_LOG(WARN)
+        << "`--use_nnapi=true` must be set for the provided NNAPI accelerator ("
+        << params_.Get<std::string>("nnapi_accelerator_name")
+        << ") to be used.";
   }
   return delegates;
 }
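With the benchmark changes above in place, the accelerator is selected entirely from the command line. A hypothetical device invocation (the paths and accelerator name are illustrative; valid names are device-specific and require Android Q+) might look like `benchmark_model --graph=/data/local/tmp/model.tflite --use_nnapi=true --nnapi_accelerator_name=example-dsp`. Passing `--nnapi_accelerator_name` without `--use_nnapi=true` now emits the warning above instead of silently ignoring the flag. The supporting plumbing in the evaluation utilities follows.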
diff --git a/tensorflow/lite/tools/evaluation/utils.cc b/tensorflow/lite/tools/evaluation/utils.cc
index 6c5e5d94e8c..162acbabf7b 100644
--- a/tensorflow/lite/tools/evaluation/utils.cc
+++ b/tensorflow/lite/tools/evaluation/utils.cc
@@ -23,8 +23,6 @@ limitations under the License.
 #include <fstream>
 #include <memory>
 
-#include "tensorflow/lite/delegates/nnapi/nnapi_delegate.h"
-
 namespace tflite {
 namespace evaluation {
 
@@ -86,6 +84,18 @@ Interpreter::TfLiteDelegatePtr CreateNNAPIDelegate() {
 #endif  // defined(__ANDROID__)
 }
 
+Interpreter::TfLiteDelegatePtr CreateNNAPIDelegate(
+    StatefulNnApiDelegate::Options options) {
+#if defined(__ANDROID__)
+  return Interpreter::TfLiteDelegatePtr(
+      new StatefulNnApiDelegate(options), [](TfLiteDelegate* delegate) {
+        delete reinterpret_cast<StatefulNnApiDelegate*>(delegate);
+      });
+#else
+  return Interpreter::TfLiteDelegatePtr(nullptr, [](TfLiteDelegate*) {});
+#endif  // defined(__ANDROID__)
+}
+
 #if defined(__ANDROID__)
 Interpreter::TfLiteDelegatePtr CreateGPUDelegate(
     tflite::FlatBufferModel* model, TfLiteGpuDelegateOptions* options) {
diff --git a/tensorflow/lite/tools/evaluation/utils.h b/tensorflow/lite/tools/evaluation/utils.h
index b9fb92882fd..877a6493fbb 100644
--- a/tensorflow/lite/tools/evaluation/utils.h
+++ b/tensorflow/lite/tools/evaluation/utils.h
@@ -24,6 +24,7 @@ limitations under the License.
 #endif
 
 #include "tensorflow/lite/context.h"
+#include "tensorflow/lite/delegates/nnapi/nnapi_delegate.h"
 #include "tensorflow/lite/model.h"
 
 namespace tflite {
@@ -38,6 +39,9 @@ TfLiteStatus GetSortedFileNames(const std::string& directory,
 Interpreter::TfLiteDelegatePtr CreateNNAPIDelegate();
 
+Interpreter::TfLiteDelegatePtr CreateNNAPIDelegate(
+    StatefulNnApiDelegate::Options options);
+
 Interpreter::TfLiteDelegatePtr CreateGPUDelegate(FlatBufferModel* model);
 
 #if defined(__ANDROID__)
 Interpreter::TfLiteDelegatePtr CreateGPUDelegate(
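For callers outside the benchmark tool, the new `CreateNNAPIDelegate(StatefulNnApiDelegate::Options)` overload can be used directly. Below is a minimal sketch, assuming an Android build and an already-constructed `tflite::Interpreter`; the accelerator name is a hypothetical placeholder, and the helper function name is ours, not part of this change.

#include <string>

#include "tensorflow/lite/delegates/nnapi/nnapi_delegate.h"
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/tools/evaluation/utils.h"

// Sketch: pin the NNAPI delegate to a named accelerator and apply it.
// The delegate returned through `delegate` must outlive the interpreter's
// use of it, so the caller keeps ownership.
TfLiteStatus ApplyNamedNnapiDelegate(
    tflite::Interpreter* interpreter,
    tflite::Interpreter::TfLiteDelegatePtr* delegate) {
  tflite::StatefulNnApiDelegate::Options options;
  // Options stores a raw const char*, so the backing string must stay
  // alive until CreateNNAPIDelegate() has constructed the delegate.
  std::string accelerator_name = "example-dsp";  // hypothetical name
  options.accelerator_name = accelerator_name.c_str();

  *delegate = tflite::evaluation::CreateNNAPIDelegate(options);
  if (!*delegate) {
    // Non-Android builds get a null delegate from the overload above.
    return kTfLiteError;
  }
  return interpreter->ModifyGraphWithDelegate(delegate->get());
}

On Android Q and later, valid accelerator names can be enumerated through the NNAPI introspection API (`ANeuralNetworks_getDeviceCount` and `ANeuralNetworksDevice_getName`).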