Allow specifying the nnapi accelerator in benchmark_model

Add a `--nnapi_accelerator_name` flag to benchmark_model, allowing
explicit accelerator control on Android Q+.

PiperOrigin-RevId: 258626781
Jared Duke 2019-07-17 13:08:01 -07:00 committed by TensorFlower Gardener
parent 75e3f9288e
commit d88b16b545
3 changed files with 38 additions and 3 deletions
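
For reference, a typical invocation with the new flag on an Android Q+ device might look like this (the binary and model paths are illustrative; "nnapi-reference" names NNAPI's reference CPU implementation, and real accelerator names are vendor-specific):

adb shell /data/local/tmp/benchmark_model \
    --graph=/data/local/tmp/model.tflite \
    --use_nnapi=true \
    --nnapi_accelerator_name=nnapi-reference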

tensorflow/lite/tools/benchmark/benchmark_tflite_model.cc

@@ -233,6 +233,8 @@ BenchmarkParams BenchmarkTfLiteModel::DefaultParams() {
   default_params.AddParam("use_nnapi", BenchmarkParam::Create<bool>(false));
   default_params.AddParam("use_legacy_nnapi",
                           BenchmarkParam::Create<bool>(false));
+  default_params.AddParam("nnapi_accelerator_name",
+                          BenchmarkParam::Create<std::string>(""));
   default_params.AddParam("use_gpu", BenchmarkParam::Create<bool>(false));
   default_params.AddParam("allow_fp16", BenchmarkParam::Create<bool>(false));
   default_params.AddParam(
@@ -271,6 +273,9 @@ std::vector<Flag> BenchmarkTfLiteModel::GetFlags() {
                        "input layer shape"),
       CreateFlag<bool>("use_nnapi", &params_, "use nnapi delegate api"),
       CreateFlag<bool>("use_legacy_nnapi", &params_, "use legacy nnapi api"),
+      CreateFlag<std::string>(
+          "nnapi_accelerator_name", &params_,
+          "the name of the nnapi accelerator to use (requires Android Q+)"),
       CreateFlag<bool>("use_gpu", &params_, "use gpu"),
       CreateFlag<bool>("allow_fp16", &params_, "allow fp16"),
       CreateFlag<bool>("enable_op_profiling", &params_, "enable op profiling"),
@@ -291,6 +296,10 @@ void BenchmarkTfLiteModel::LogParams() {
   TFLITE_LOG(INFO) << "Use nnapi : [" << params_.Get<bool>("use_nnapi") << "]";
   TFLITE_LOG(INFO) << "Use legacy nnapi : ["
                    << params_.Get<bool>("use_legacy_nnapi") << "]";
+  if (params_.HasParam("nnapi_accelerator_name")) {
+    TFLITE_LOG(INFO) << "nnapi accelerator name: ["
+                     << params_.Get<std::string>("nnapi_accelerator_name")
+                     << "]";
+  }
   TFLITE_LOG(INFO) << "Use gpu : [" << params_.Get<bool>("use_gpu") << "]";
   TFLITE_LOG(INFO) << "Allow fp16 : [" << params_.Get<bool>("allow_fp16")
                    << "]";
@@ -506,12 +515,24 @@ BenchmarkTfLiteModel::TfLiteDelegatePtrMap BenchmarkTfLiteModel::GetDelegates()
     }
   }
   if (params_.Get<bool>("use_nnapi")) {
-    Interpreter::TfLiteDelegatePtr delegate = evaluation::CreateNNAPIDelegate();
+    StatefulNnApiDelegate::Options options;
+    std::string accelerator_name;
+    if (params_.HasParam("nnapi_accelerator_name")) {
+      accelerator_name = params_.Get<std::string>("nnapi_accelerator_name");
+      options.accelerator_name = accelerator_name.c_str();
+    }
+    Interpreter::TfLiteDelegatePtr delegate =
+        evaluation::CreateNNAPIDelegate(options);
     if (!delegate) {
       TFLITE_LOG(WARN) << "NNAPI acceleration is unsupported on this platform.";
     } else {
       delegates.emplace("NNAPI", std::move(delegate));
     }
+  } else if (params_.HasParam("nnapi_accelerator_name")) {
+    TFLITE_LOG(WARN)
+        << "`--use_nnapi=true` must be set for the provided NNAPI accelerator ("
+        << params_.Get<std::string>("nnapi_accelerator_name")
+        << ") to be used.";
   }
   return delegates;
 }
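
The new flag maps directly onto StatefulNnApiDelegate::Options; note that GetDelegates() keeps the accelerator name in a named std::string so the c_str() pointer stays valid while the delegate is constructed. A minimal sketch of the equivalent direct API usage (the accelerator name and the pre-built interpreter are illustrative, not part of this change):

#include "tensorflow/lite/delegates/nnapi/nnapi_delegate.h"

// Target a specific NNAPI device by name (Android Q+). Names are
// vendor-specific; "nnapi-reference" is the reference CPU implementation.
tflite::StatefulNnApiDelegate::Options options;
options.accelerator_name = "nnapi-reference";
tflite::StatefulNnApiDelegate nnapi_delegate(options);
// `interpreter` is an already-built tflite::Interpreter (setup omitted).
interpreter->ModifyGraphWithDelegate(&nnapi_delegate);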

tensorflow/lite/tools/evaluation/utils.cc

@@ -23,8 +23,6 @@ limitations under the License.
 
 #include <memory>
 #include <string>
 
-#include "tensorflow/lite/delegates/nnapi/nnapi_delegate.h"
-
 namespace tflite {
 namespace evaluation {
@@ -86,6 +84,18 @@ Interpreter::TfLiteDelegatePtr CreateNNAPIDelegate() {
 #endif  // defined(__ANDROID__)
 }
 
+Interpreter::TfLiteDelegatePtr CreateNNAPIDelegate(
+    StatefulNnApiDelegate::Options options) {
+#if defined(__ANDROID__)
+  return Interpreter::TfLiteDelegatePtr(
+      new StatefulNnApiDelegate(options), [](TfLiteDelegate* delegate) {
+        delete reinterpret_cast<StatefulNnApiDelegate*>(delegate);
+      });
+#else
+  return Interpreter::TfLiteDelegatePtr(nullptr, [](TfLiteDelegate*) {});
+#endif  // defined(__ANDROID__)
+}
+
 #if defined(__ANDROID__)
 Interpreter::TfLiteDelegatePtr CreateGPUDelegate(
     tflite::FlatBufferModel* model, TfLiteGpuDelegateOptions* options) {
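
In the deleter above, the TfLiteDelegate* is cast back to StatefulNnApiDelegate* before delete because TfLiteDelegate is a plain C struct with no virtual destructor, so deleting through the base pointer would skip the derived destructor. A self-contained sketch of the same ownership pattern (Base and Derived are stand-ins, not TensorFlow Lite types):

#include <memory>

struct Base { int id; };                // no virtual destructor, like TfLiteDelegate
struct Derived : Base { /* extra owned state */ };

// Mirrors Interpreter::TfLiteDelegatePtr: a unique_ptr with a custom deleter.
using BasePtr = std::unique_ptr<Base, void (*)(Base*)>;

BasePtr MakeDerived() {
  return BasePtr(new Derived(), [](Base* b) {
    // Restore the concrete type so the correct destructor runs.
    delete static_cast<Derived*>(b);
  });
}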

tensorflow/lite/tools/evaluation/utils.h

@@ -24,6 +24,7 @@ limitations under the License.
 #endif
 
 #include "tensorflow/lite/context.h"
+#include "tensorflow/lite/delegates/nnapi/nnapi_delegate.h"
 #include "tensorflow/lite/model.h"
 
 namespace tflite {
@@ -38,6 +39,9 @@ TfLiteStatus GetSortedFileNames(const std::string& directory,
 Interpreter::TfLiteDelegatePtr CreateNNAPIDelegate();
 
+Interpreter::TfLiteDelegatePtr CreateNNAPIDelegate(
+    StatefulNnApiDelegate::Options options);
+
 Interpreter::TfLiteDelegatePtr CreateGPUDelegate(FlatBufferModel* model);
 
 #if defined(__ANDROID__)
 Interpreter::TfLiteDelegatePtr CreateGPUDelegate(