Modify TF Lite benchmark to print names of available NNAPI accelerators when --use_nnapi=true

PiperOrigin-RevId: 266040746
This commit is contained in:
Tyler Davis 2019-08-28 18:33:39 -07:00 committed by TensorFlower Gardener
parent 386da9758d
commit a1a5f93073
8 changed files with 142 additions and 22 deletions

View File

@ -34,6 +34,7 @@ cc_library(
"//tensorflow/lite/c:c_api_internal",
"//tensorflow/lite/kernels:kernel_util",
"//tensorflow/lite/nnapi:nnapi_implementation",
"//tensorflow/lite/nnapi:nnapi_util",
],
)

View File

@ -51,6 +51,7 @@ limitations under the License.
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/minimal_logging.h"
#include "tensorflow/lite/nnapi/nnapi_implementation.h"
#include "tensorflow/lite/nnapi/nnapi_util.h"
#include "tensorflow/lite/util.h"
namespace tflite {
@ -297,21 +298,6 @@ static size_t getNumPaddingBytes(size_t byte_size) {
return num_padding_bytes;
}
std::string SimpleJoin(const std::vector<const char*>& elements,
const char* separator) {
// Note that we avoid use of sstream to avoid binary size bloat.
std::string joined_elements;
for (auto it = elements.begin(); it != elements.end(); ++it) {
if (separator && it != elements.begin()) {
joined_elements += separator;
}
if (*it) {
joined_elements += *it;
}
}
return joined_elements;
}
// Return NNAPI device handle with the provided null-terminated device name. If
// no matching device could be found, nullptr will be returned.
ANeuralNetworksDevice* GetDeviceHandle(TfLiteContext* context,
@ -322,7 +308,6 @@ ANeuralNetworksDevice* GetDeviceHandle(TfLiteContext* context,
uint32_t num_devices = 0;
NnApiImplementation()->ANeuralNetworks_getDeviceCount(&num_devices);
std::vector<const char*> device_names;
for (uint32_t i = 0; i < num_devices; i++) {
ANeuralNetworksDevice* device = nullptr;
const char* buffer = nullptr;
@ -332,14 +317,13 @@ ANeuralNetworksDevice* GetDeviceHandle(TfLiteContext* context,
device_handle = device;
break;
}
device_names.push_back(buffer);
}
if (!device_handle) {
context->ReportError(context,
"Could not find the specified NNAPI accelerator: %s. "
"Must be one of: {%s}.",
device_name_ptr,
SimpleJoin(device_names, ",").c_str());
nnapi::GetStringDeviceNamesList().c_str());
}
return device_handle;
}

View File

@ -61,6 +61,17 @@ cc_library(
],
)
cc_library(
name = "nnapi_util",
srcs = ["nnapi_util.cc"],
hdrs = ["nnapi_util.h"],
deps = [
":nnapi_implementation",
"//tensorflow/lite:util",
"//tensorflow/lite/c:c_api_internal",
],
)
cc_test(
name = "nnapi_implementation_test",
srcs = ["nnapi_implementation_test.cc"],

View File

@ -0,0 +1,73 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/nnapi/nnapi_util.h"
#include <string>
#include <vector>
#include "tensorflow/lite/nnapi/nnapi_implementation.h"
#include "tensorflow/lite/util.h"
namespace tflite {
namespace nnapi {
namespace {
// Concatenates the given C strings into a single std::string, inserting
// `separator` between consecutive positions (including before a null
// element). Null element pointers contribute no characters; a null
// separator joins the elements with nothing in between.
// Note that we avoid use of sstream to avoid binary size bloat.
std::string SimpleJoin(const std::vector<const char*>& elements,
                       const char* separator) {
  std::string result;
  bool first = true;
  for (const char* element : elements) {
    if (!first && separator != nullptr) {
      result += separator;
    }
    first = false;
    if (element != nullptr) {
      result += element;
    }
  }
  return result;
}
} // namespace
std::vector<const char*> GetDeviceNamesList() {
std::vector<const char*> device_names;
// Only build the list if NnApiImplementation has the methods we need,
// leaving it empty otherwise.
if (NnApiImplementation()->ANeuralNetworks_getDeviceCount != nullptr) {
uint32_t num_devices = 0;
NnApiImplementation()->ANeuralNetworks_getDeviceCount(&num_devices);
for (uint32_t i = 0; i < num_devices; i++) {
ANeuralNetworksDevice* device = nullptr;
const char* buffer = nullptr;
NnApiImplementation()->ANeuralNetworks_getDevice(i, &device);
NnApiImplementation()->ANeuralNetworksDevice_getName(device, &buffer);
device_names.push_back(buffer);
}
}
return device_names;
}
// Builds a comma-separated list of all available NNAPI device names,
// e.g. "DeviceA,DeviceB,DeviceC". Empty when no devices are reported.
std::string GetStringDeviceNamesList() {
  return SimpleJoin(GetDeviceNamesList(), ",");
}
} // namespace nnapi
} // namespace tflite

View File

@ -0,0 +1,38 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This file provides general C++ utility functions for interacting with NNAPI.
#ifndef TENSORFLOW_LITE_NNAPI_NNAPI_UTIL_H_
#define TENSORFLOW_LITE_NNAPI_NNAPI_UTIL_H_
#include <string>
#include <vector>
namespace tflite {
namespace nnapi {
// Return std::vector consisting of pointers to null-terminated device names.
// These names are guaranteed valid for the lifetime of the application.
std::vector<const char*> GetDeviceNamesList();
// Return a string containing the names of all available devices.
// Will take the format: "DeviceA,DeviceB,DeviceC"
std::string GetStringDeviceNamesList();
} // namespace nnapi
} // namespace tflite
#endif // TENSORFLOW_LITE_NNAPI_NNAPI_UTIL_H_

View File

@ -112,6 +112,7 @@ cc_library(
"//tensorflow/lite:framework",
"//tensorflow/lite:string_util",
"//tensorflow/lite/kernels:builtin_ops",
"//tensorflow/lite/nnapi:nnapi_util",
"//tensorflow/lite/profiling:profile_summarizer",
"//tensorflow/lite/profiling:profiler",
"//tensorflow/lite/tools/evaluation:utils",

View File

@ -38,7 +38,9 @@ and the following optional parameters:
Whether to use [Android NNAPI](https://developer.android.com/ndk/guides/neuralnetworks/).
This API is available on recent Android devices. Note that some Android P
devices will fail to use NNAPI for models in `/data/local/tmp/` and this
benchmark tool will not correctly use NNAPI.
benchmark tool will not correctly use NNAPI. On Android Q+, this option
will also print the names of the NNAPI accelerators that can be selected
via the `nnapi_accelerator_name` flag.
* `nnapi_accelerator_name`: `str` (default="") \
The name of the NNAPI accelerator to use (requires Android Q+). If left
blank, NNAPI will automatically select which of the available accelerators

View File

@ -23,6 +23,8 @@ limitations under the License.
#include <unordered_set>
#include <vector>
#include "tensorflow/lite/nnapi/nnapi_util.h"
#if defined(__ANDROID__)
#include "tensorflow/lite/delegates/gpu/gl_delegate.h"
#endif
@ -302,9 +304,17 @@ void BenchmarkTfLiteModel::LogParams() {
}
TFLITE_LOG(INFO) << "Use legacy nnapi : ["
<< params_.Get<bool>("use_legacy_nnapi") << "]";
if (!params_.Get<std::string>("nnapi_accelerator_name").empty()) {
TFLITE_LOG(INFO) << "nnapi accelerator name: ["
<< params_.Get<string>("nnapi_accelerator_name") << "]";
if (params_.Get<bool>("use_nnapi")) {
std::string log_string = "nnapi accelerator name: [" +
params_.Get<string>("nnapi_accelerator_name") +
"]";
std::string string_device_names_list = nnapi::GetStringDeviceNamesList();
// Print available devices when possible
if (!string_device_names_list.empty()) {
log_string += " (Available: " + string_device_names_list + ")";
}
TFLITE_LOG(INFO) << log_string;
}
TFLITE_LOG(INFO) << "Use gpu : [" << params_.Get<bool>("use_gpu") << "]";
#if defined(__ANDROID__)