Allow specifying the nnapi accelerator in benchmark_model
Add a `--nnapi_accelerator_name` flag to benchmark_model, allowing explicit accelerator control on Android Q+. PiperOrigin-RevId: 258626781
This commit is contained in:
parent
75e3f9288e
commit
d88b16b545
@ -233,6 +233,8 @@ BenchmarkParams BenchmarkTfLiteModel::DefaultParams() {
|
|||||||
default_params.AddParam("use_nnapi", BenchmarkParam::Create<bool>(false));
|
default_params.AddParam("use_nnapi", BenchmarkParam::Create<bool>(false));
|
||||||
default_params.AddParam("use_legacy_nnapi",
|
default_params.AddParam("use_legacy_nnapi",
|
||||||
BenchmarkParam::Create<bool>(false));
|
BenchmarkParam::Create<bool>(false));
|
||||||
|
default_params.AddParam("nnapi_accelerator_name",
|
||||||
|
BenchmarkParam::Create<std::string>(""));
|
||||||
default_params.AddParam("use_gpu", BenchmarkParam::Create<bool>(false));
|
default_params.AddParam("use_gpu", BenchmarkParam::Create<bool>(false));
|
||||||
default_params.AddParam("allow_fp16", BenchmarkParam::Create<bool>(false));
|
default_params.AddParam("allow_fp16", BenchmarkParam::Create<bool>(false));
|
||||||
default_params.AddParam(
|
default_params.AddParam(
|
||||||
@ -271,6 +273,9 @@ std::vector<Flag> BenchmarkTfLiteModel::GetFlags() {
|
|||||||
"input layer shape"),
|
"input layer shape"),
|
||||||
CreateFlag<bool>("use_nnapi", ¶ms_, "use nnapi delegate api"),
|
CreateFlag<bool>("use_nnapi", ¶ms_, "use nnapi delegate api"),
|
||||||
CreateFlag<bool>("use_legacy_nnapi", ¶ms_, "use legacy nnapi api"),
|
CreateFlag<bool>("use_legacy_nnapi", ¶ms_, "use legacy nnapi api"),
|
||||||
|
CreateFlag<std::string>(
|
||||||
|
"nnapi_accelerator_name", ¶ms_,
|
||||||
|
"the name of the nnapi accelerator to use (requires Android Q+)"),
|
||||||
CreateFlag<bool>("use_gpu", ¶ms_, "use gpu"),
|
CreateFlag<bool>("use_gpu", ¶ms_, "use gpu"),
|
||||||
CreateFlag<bool>("allow_fp16", ¶ms_, "allow fp16"),
|
CreateFlag<bool>("allow_fp16", ¶ms_, "allow fp16"),
|
||||||
CreateFlag<bool>("enable_op_profiling", ¶ms_, "enable op profiling"),
|
CreateFlag<bool>("enable_op_profiling", ¶ms_, "enable op profiling"),
|
||||||
@ -291,6 +296,10 @@ void BenchmarkTfLiteModel::LogParams() {
|
|||||||
TFLITE_LOG(INFO) << "Use nnapi : [" << params_.Get<bool>("use_nnapi") << "]";
|
TFLITE_LOG(INFO) << "Use nnapi : [" << params_.Get<bool>("use_nnapi") << "]";
|
||||||
TFLITE_LOG(INFO) << "Use legacy nnapi : ["
|
TFLITE_LOG(INFO) << "Use legacy nnapi : ["
|
||||||
<< params_.Get<bool>("use_legacy_nnapi") << "]";
|
<< params_.Get<bool>("use_legacy_nnapi") << "]";
|
||||||
|
if (params_.HasParam("nnapi_accelerator_name")) {
|
||||||
|
TFLITE_LOG(INFO) << "nnapi accelerator name: ["
|
||||||
|
<< params_.Get<string>("nnapi_accelerator_name") << "]";
|
||||||
|
}
|
||||||
TFLITE_LOG(INFO) << "Use gpu : [" << params_.Get<bool>("use_gpu") << "]";
|
TFLITE_LOG(INFO) << "Use gpu : [" << params_.Get<bool>("use_gpu") << "]";
|
||||||
TFLITE_LOG(INFO) << "Allow fp16 : [" << params_.Get<bool>("allow_fp16")
|
TFLITE_LOG(INFO) << "Allow fp16 : [" << params_.Get<bool>("allow_fp16")
|
||||||
<< "]";
|
<< "]";
|
||||||
@ -506,12 +515,24 @@ BenchmarkTfLiteModel::TfLiteDelegatePtrMap BenchmarkTfLiteModel::GetDelegates()
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (params_.Get<bool>("use_nnapi")) {
|
if (params_.Get<bool>("use_nnapi")) {
|
||||||
Interpreter::TfLiteDelegatePtr delegate = evaluation::CreateNNAPIDelegate();
|
StatefulNnApiDelegate::Options options;
|
||||||
|
std::string accelerator_name;
|
||||||
|
if (params_.HasParam("nnapi_accelerator_name")) {
|
||||||
|
accelerator_name = params_.Get<std::string>("nnapi_accelerator_name");
|
||||||
|
options.accelerator_name = accelerator_name.c_str();
|
||||||
|
}
|
||||||
|
Interpreter::TfLiteDelegatePtr delegate =
|
||||||
|
evaluation::CreateNNAPIDelegate(options);
|
||||||
if (!delegate) {
|
if (!delegate) {
|
||||||
TFLITE_LOG(WARN) << "NNAPI acceleration is unsupported on this platform.";
|
TFLITE_LOG(WARN) << "NNAPI acceleration is unsupported on this platform.";
|
||||||
} else {
|
} else {
|
||||||
delegates.emplace("NNAPI", std::move(delegate));
|
delegates.emplace("NNAPI", std::move(delegate));
|
||||||
}
|
}
|
||||||
|
} else if (params_.HasParam("nnapi_accelerator_name")) {
|
||||||
|
TFLITE_LOG(WARN)
|
||||||
|
<< "`--use_nnapi=true` must be set for the provided NNAPI accelerator ("
|
||||||
|
<< params_.Get<std::string>("nnapi_accelerator_name")
|
||||||
|
<< ") to be used.";
|
||||||
}
|
}
|
||||||
return delegates;
|
return delegates;
|
||||||
}
|
}
|
||||||
|
@ -23,8 +23,6 @@ limitations under the License.
|
|||||||
#include <memory>
|
#include <memory>
|
||||||
#include <string>
|
#include <string>
|
||||||
|
|
||||||
#include "tensorflow/lite/delegates/nnapi/nnapi_delegate.h"
|
|
||||||
|
|
||||||
namespace tflite {
|
namespace tflite {
|
||||||
namespace evaluation {
|
namespace evaluation {
|
||||||
|
|
||||||
@ -86,6 +84,18 @@ Interpreter::TfLiteDelegatePtr CreateNNAPIDelegate() {
|
|||||||
#endif // defined(__ANDROID__)
|
#endif // defined(__ANDROID__)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Interpreter::TfLiteDelegatePtr CreateNNAPIDelegate(
|
||||||
|
StatefulNnApiDelegate::Options options) {
|
||||||
|
#if defined(__ANDROID__)
|
||||||
|
return Interpreter::TfLiteDelegatePtr(
|
||||||
|
new StatefulNnApiDelegate(options), [](TfLiteDelegate* delegate) {
|
||||||
|
delete reinterpret_cast<StatefulNnApiDelegate*>(delegate);
|
||||||
|
});
|
||||||
|
#else
|
||||||
|
return Interpreter::TfLiteDelegatePtr(nullptr, [](TfLiteDelegate*) {});
|
||||||
|
#endif // defined(__ANDROID__)
|
||||||
|
}
|
||||||
|
|
||||||
#if defined(__ANDROID__)
|
#if defined(__ANDROID__)
|
||||||
Interpreter::TfLiteDelegatePtr CreateGPUDelegate(
|
Interpreter::TfLiteDelegatePtr CreateGPUDelegate(
|
||||||
tflite::FlatBufferModel* model, TfLiteGpuDelegateOptions* options) {
|
tflite::FlatBufferModel* model, TfLiteGpuDelegateOptions* options) {
|
||||||
|
@ -24,6 +24,7 @@ limitations under the License.
|
|||||||
#endif
|
#endif
|
||||||
|
|
||||||
#include "tensorflow/lite/context.h"
|
#include "tensorflow/lite/context.h"
|
||||||
|
#include "tensorflow/lite/delegates/nnapi/nnapi_delegate.h"
|
||||||
#include "tensorflow/lite/model.h"
|
#include "tensorflow/lite/model.h"
|
||||||
|
|
||||||
namespace tflite {
|
namespace tflite {
|
||||||
@ -38,6 +39,9 @@ TfLiteStatus GetSortedFileNames(const std::string& directory,
|
|||||||
|
|
||||||
Interpreter::TfLiteDelegatePtr CreateNNAPIDelegate();
|
Interpreter::TfLiteDelegatePtr CreateNNAPIDelegate();
|
||||||
|
|
||||||
|
Interpreter::TfLiteDelegatePtr CreateNNAPIDelegate(
|
||||||
|
StatefulNnApiDelegate::Options options);
|
||||||
|
|
||||||
Interpreter::TfLiteDelegatePtr CreateGPUDelegate(FlatBufferModel* model);
|
Interpreter::TfLiteDelegatePtr CreateGPUDelegate(FlatBufferModel* model);
|
||||||
#if defined(__ANDROID__)
|
#if defined(__ANDROID__)
|
||||||
Interpreter::TfLiteDelegatePtr CreateGPUDelegate(
|
Interpreter::TfLiteDelegatePtr CreateGPUDelegate(
|
||||||
|
Loading…
Reference in New Issue
Block a user