Allow specififying the nnapi accelerator in benchmark_model

Add a `--nnapi_accelerator_name` flag to benchmark_model, allowing
explicit accelerator control on Android Q+.

PiperOrigin-RevId: 258626781
This commit is contained in:
Jared Duke 2019-07-17 13:08:01 -07:00 committed by TensorFlower Gardener
parent 75e3f9288e
commit d88b16b545
3 changed files with 38 additions and 3 deletions

View File

@ -233,6 +233,8 @@ BenchmarkParams BenchmarkTfLiteModel::DefaultParams() {
default_params.AddParam("use_nnapi", BenchmarkParam::Create<bool>(false)); default_params.AddParam("use_nnapi", BenchmarkParam::Create<bool>(false));
default_params.AddParam("use_legacy_nnapi", default_params.AddParam("use_legacy_nnapi",
BenchmarkParam::Create<bool>(false)); BenchmarkParam::Create<bool>(false));
default_params.AddParam("nnapi_accelerator_name",
BenchmarkParam::Create<std::string>(""));
default_params.AddParam("use_gpu", BenchmarkParam::Create<bool>(false)); default_params.AddParam("use_gpu", BenchmarkParam::Create<bool>(false));
default_params.AddParam("allow_fp16", BenchmarkParam::Create<bool>(false)); default_params.AddParam("allow_fp16", BenchmarkParam::Create<bool>(false));
default_params.AddParam( default_params.AddParam(
@ -271,6 +273,9 @@ std::vector<Flag> BenchmarkTfLiteModel::GetFlags() {
"input layer shape"), "input layer shape"),
CreateFlag<bool>("use_nnapi", &params_, "use nnapi delegate api"), CreateFlag<bool>("use_nnapi", &params_, "use nnapi delegate api"),
CreateFlag<bool>("use_legacy_nnapi", &params_, "use legacy nnapi api"), CreateFlag<bool>("use_legacy_nnapi", &params_, "use legacy nnapi api"),
CreateFlag<std::string>(
"nnapi_accelerator_name", &params_,
"the name of the nnapi accelerator to use (requires Android Q+)"),
CreateFlag<bool>("use_gpu", &params_, "use gpu"), CreateFlag<bool>("use_gpu", &params_, "use gpu"),
CreateFlag<bool>("allow_fp16", &params_, "allow fp16"), CreateFlag<bool>("allow_fp16", &params_, "allow fp16"),
CreateFlag<bool>("enable_op_profiling", &params_, "enable op profiling"), CreateFlag<bool>("enable_op_profiling", &params_, "enable op profiling"),
@ -291,6 +296,10 @@ void BenchmarkTfLiteModel::LogParams() {
TFLITE_LOG(INFO) << "Use nnapi : [" << params_.Get<bool>("use_nnapi") << "]"; TFLITE_LOG(INFO) << "Use nnapi : [" << params_.Get<bool>("use_nnapi") << "]";
TFLITE_LOG(INFO) << "Use legacy nnapi : [" TFLITE_LOG(INFO) << "Use legacy nnapi : ["
<< params_.Get<bool>("use_legacy_nnapi") << "]"; << params_.Get<bool>("use_legacy_nnapi") << "]";
if (params_.HasParam("nnapi_accelerator_name")) {
TFLITE_LOG(INFO) << "nnapi accelerator name: ["
<< params_.Get<string>("nnapi_accelerator_name") << "]";
}
TFLITE_LOG(INFO) << "Use gpu : [" << params_.Get<bool>("use_gpu") << "]"; TFLITE_LOG(INFO) << "Use gpu : [" << params_.Get<bool>("use_gpu") << "]";
TFLITE_LOG(INFO) << "Allow fp16 : [" << params_.Get<bool>("allow_fp16") TFLITE_LOG(INFO) << "Allow fp16 : [" << params_.Get<bool>("allow_fp16")
<< "]"; << "]";
@ -506,12 +515,24 @@ BenchmarkTfLiteModel::TfLiteDelegatePtrMap BenchmarkTfLiteModel::GetDelegates()
} }
} }
if (params_.Get<bool>("use_nnapi")) { if (params_.Get<bool>("use_nnapi")) {
Interpreter::TfLiteDelegatePtr delegate = evaluation::CreateNNAPIDelegate(); StatefulNnApiDelegate::Options options;
std::string accelerator_name;
if (params_.HasParam("nnapi_accelerator_name")) {
accelerator_name = params_.Get<std::string>("nnapi_accelerator_name");
options.accelerator_name = accelerator_name.c_str();
}
Interpreter::TfLiteDelegatePtr delegate =
evaluation::CreateNNAPIDelegate(options);
if (!delegate) { if (!delegate) {
TFLITE_LOG(WARN) << "NNAPI acceleration is unsupported on this platform."; TFLITE_LOG(WARN) << "NNAPI acceleration is unsupported on this platform.";
} else { } else {
delegates.emplace("NNAPI", std::move(delegate)); delegates.emplace("NNAPI", std::move(delegate));
} }
} else if (params_.HasParam("nnapi_accelerator_name")) {
TFLITE_LOG(WARN)
<< "`--use_nnapi=true` must be set for the provided NNAPI accelerator ("
<< params_.Get<std::string>("nnapi_accelerator_name")
<< ") to be used.";
} }
return delegates; return delegates;
} }

View File

@ -23,8 +23,6 @@ limitations under the License.
#include <memory> #include <memory>
#include <string> #include <string>
#include "tensorflow/lite/delegates/nnapi/nnapi_delegate.h"
namespace tflite { namespace tflite {
namespace evaluation { namespace evaluation {
@ -86,6 +84,18 @@ Interpreter::TfLiteDelegatePtr CreateNNAPIDelegate() {
#endif // defined(__ANDROID__) #endif // defined(__ANDROID__)
} }
Interpreter::TfLiteDelegatePtr CreateNNAPIDelegate(
StatefulNnApiDelegate::Options options) {
#if defined(__ANDROID__)
return Interpreter::TfLiteDelegatePtr(
new StatefulNnApiDelegate(options), [](TfLiteDelegate* delegate) {
delete reinterpret_cast<StatefulNnApiDelegate*>(delegate);
});
#else
return Interpreter::TfLiteDelegatePtr(nullptr, [](TfLiteDelegate*) {});
#endif // defined(__ANDROID__)
}
#if defined(__ANDROID__) #if defined(__ANDROID__)
Interpreter::TfLiteDelegatePtr CreateGPUDelegate( Interpreter::TfLiteDelegatePtr CreateGPUDelegate(
tflite::FlatBufferModel* model, TfLiteGpuDelegateOptions* options) { tflite::FlatBufferModel* model, TfLiteGpuDelegateOptions* options) {

View File

@ -24,6 +24,7 @@ limitations under the License.
#endif #endif
#include "tensorflow/lite/context.h" #include "tensorflow/lite/context.h"
#include "tensorflow/lite/delegates/nnapi/nnapi_delegate.h"
#include "tensorflow/lite/model.h" #include "tensorflow/lite/model.h"
namespace tflite { namespace tflite {
@ -38,6 +39,9 @@ TfLiteStatus GetSortedFileNames(const std::string& directory,
Interpreter::TfLiteDelegatePtr CreateNNAPIDelegate(); Interpreter::TfLiteDelegatePtr CreateNNAPIDelegate();
Interpreter::TfLiteDelegatePtr CreateNNAPIDelegate(
StatefulNnApiDelegate::Options options);
Interpreter::TfLiteDelegatePtr CreateGPUDelegate(FlatBufferModel* model); Interpreter::TfLiteDelegatePtr CreateGPUDelegate(FlatBufferModel* model);
#if defined(__ANDROID__) #if defined(__ANDROID__)
Interpreter::TfLiteDelegatePtr CreateGPUDelegate( Interpreter::TfLiteDelegatePtr CreateGPUDelegate(