Support the delegate registrar (https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/tools/delegates) for the label_image example so that the various delegate options supported by the TFLite benchmark and task evaluation tools can be reused here.

PiperOrigin-RevId: 334428572
Change-Id: Ib942b733ee4b0fe4a1a410bf38d9db1173dc00c9
Chao Mei 2020-09-29 11:50:15 -07:00 committed by TensorFlower Gardener
parent 7ed2ba3fa4
commit 905e0be75a
3 changed files with 140 additions and 41 deletions

tensorflow/lite/examples/label_image/BUILD

@ -29,16 +29,20 @@ cc_binary(
     }),
     deps = [
         ":bitmap_helpers",
-        "//tensorflow/lite/c:common",
-        "@com_google_absl//absl/memory",
-        "@com_google_absl//absl/strings",
         "//tensorflow/lite:framework",
         "//tensorflow/lite:string_util",
+        "//tensorflow/lite/c:common",
         "//tensorflow/lite/delegates/nnapi:nnapi_delegate",
         "//tensorflow/lite/delegates/xnnpack:xnnpack_delegate",
         "//tensorflow/lite/kernels:builtin_ops",
         "//tensorflow/lite/profiling:profiler",
         "//tensorflow/lite/tools:command_line_flags",
+        "//tensorflow/lite/tools:tool_params",
+        "//tensorflow/lite/tools/delegates:delegate_provider_hdr",
+        "//tensorflow/lite/tools/delegates:tflite_execution_providers",
         "//tensorflow/lite/tools/evaluation:utils",
+        "@com_google_absl//absl/memory",
+        "@com_google_absl//absl/strings",
     ] + select({
         "//tensorflow:android": [
             "//tensorflow/lite/delegates/gpu:delegate",

tensorflow/lite/examples/label_image/README.md

@ -222,3 +222,20 @@ label_image
 ```
 See the `label_image.cc` source code for other command line options.
+
+Note that this binary also supports runtime/delegate arguments introduced by
+the [delegate registrar](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/tools/delegates).
+If there is any conflict, the arguments mentioned earlier take precedence. For
+example, you can run the binary with additional command line options such as
+`--use_nnapi=true --nnapi_accelerator_name=google-edgetpu` to utilize the
+EdgeTPU in a 4th-generation Pixel phone. Note that the "=" in these options
+must not be omitted.
+
+```
+adb shell \
+  "/data/local/tmp/label_image \
+    -m /data/local/tmp/mobilenet_v1_1.0_224_quant.tflite \
+    -i /data/local/tmp/grace_hopper.bmp \
+    -l /data/local/tmp/labels.txt -j 1 \
+    --use_nnapi=true --nnapi_accelerator_name=google-edgetpu"
+```
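
The same mechanism applies to the other delegate providers linked into this binary. As an illustrative sketch (assuming the XNNPACK provider's `--use_xnnpack` flag behaves as it does in the TFLite benchmark tool), one could similarly run:

```
adb shell \
  "/data/local/tmp/label_image \
    -m /data/local/tmp/mobilenet_v1_1.0_224.tflite \
    -i /data/local/tmp/grace_hopper.bmp \
    -l /data/local/tmp/labels.txt \
    --use_xnnpack=true"
```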

tensorflow/lite/examples/label_image/label_image.cc

@ -44,6 +44,8 @@ limitations under the License.
#include "tensorflow/lite/optional_debug_tools.h"
#include "tensorflow/lite/profiling/profiler.h"
#include "tensorflow/lite/string_util.h"
#include "tensorflow/lite/tools/command_line_flags.h"
#include "tensorflow/lite/tools/delegates/delegate_provider.h"
#include "tensorflow/lite/tools/evaluation/utils.h"
#if defined(__ANDROID__)
@ -60,6 +62,56 @@ double get_us(struct timeval t) { return (t.tv_sec * 1000000 + t.tv_usec); }
 using TfLiteDelegatePtr = tflite::Interpreter::TfLiteDelegatePtr;
 using TfLiteDelegatePtrMap = std::map<std::string, TfLiteDelegatePtr>;

+class DelegateProviders {
+ public:
+  DelegateProviders()
+      : delegates_list_(tflite::tools::GetRegisteredDelegateProviders()) {
+    for (const auto& delegate : delegates_list_) {
+      params_.Merge(delegate->DefaultParams());
+    }
+  }
+
+  // Initializes delegate-related parameters by parsing command line arguments,
+  // and removes the matching arguments from (*argc, argv). Returns true if all
+  // recognized argument values are parsed correctly.
+  bool InitFromCmdlineArgs(int* argc, const char** argv) {
+    std::vector<tflite::Flag> flags;
+    for (const auto& delegate : delegates_list_) {
+      auto delegate_flags = delegate->CreateFlags(&params_);
+      flags.insert(flags.end(), delegate_flags.begin(), delegate_flags.end());
+    }
+
+    const bool parse_result = Flags::Parse(argc, argv, flags);
+    if (!parse_result) {
+      std::string usage = Flags::Usage(argv[0], flags);
+      LOG(ERROR) << usage;
+    }
+    return parse_result;
+  }
+
+  // Creates a list of TfLite delegates based on what has been initialized
+  // (i.e. 'params_').
+  TfLiteDelegatePtrMap CreateAllDelegates() const {
+    TfLiteDelegatePtrMap delegates_map;
+    for (const auto& delegate : delegates_list_) {
+      auto ptr = delegate->CreateTfLiteDelegate(params_);
+      // A delegate of a certain type may not be created if the user-specified
+      // params say not to create it.
+      if (ptr == nullptr) continue;
+      LOG(INFO) << delegate->GetName() << " delegate created.\n";
+      delegates_map.emplace(delegate->GetName(), std::move(ptr));
+    }
+    return delegates_map;
+  }
+
+ private:
+  // Contains delegate-related parameters that are initialized from
+  // command-line flags.
+  tflite::tools::ToolParams params_;
+
+  // The list of registered delegate providers.
+  const tflite::tools::DelegateProviderList& delegates_list_;
+};
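
For reference, a minimal usage sketch of the `DelegateProviders` class added above (this simply mirrors how `Main()` and `GetDelegates()` are wired together later in this diff; `argc`, `argv`, and `interpreter` are assumed to exist in the surrounding program):

```
// Sketch only: parse delegate-related flags first, then create every delegate
// the flags requested and hand each one to the interpreter.
DelegateProviders delegate_providers;
if (!delegate_providers.InitFromCmdlineArgs(&argc,
                                            const_cast<const char**>(argv))) {
  return EXIT_FAILURE;
}
for (auto& named_delegate : delegate_providers.CreateAllDelegates()) {
  // named_delegate.first is the provider name; .second owns the delegate.
  interpreter->ModifyGraphWithDelegate(named_delegate.second.get());
}
```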
 TfLiteDelegatePtr CreateGPUDelegate(Settings* s) {
 #if defined(__ANDROID__)
   TfLiteGpuDelegateOptionsV2 gpu_opts = TfLiteGpuDelegateOptionsV2Default();
@ -74,7 +126,10 @@ TfLiteDelegatePtr CreateGPUDelegate(Settings* s) {
 #endif
 }

-TfLiteDelegatePtrMap GetDelegates(Settings* s) {
+TfLiteDelegatePtrMap GetDelegates(Settings* s,
+                                  const DelegateProviders& delegate_providers) {
+  // TODO(b/169681115): deprecate the delegate creation path based on
+  // "Settings" by mapping settings to DelegateProvider's parameters.
   TfLiteDelegatePtrMap delegates;

   if (s->gl_backend) {
     auto delegate = CreateGPUDelegate(s);
@ -117,6 +172,18 @@ TfLiteDelegatePtrMap GetDelegates(Settings* s) {
     }
   }

+  // Independent of the above delegate creation options, which are specific to
+  // this binary, we use delegate providers to create TFLite delegates here.
+  // Delegate providers have been used in TFLite benchmark/evaluation tools and
+  // testing so that we have a single, more comprehensive set of command line
+  // arguments for delegate creation.
+  TfLiteDelegatePtrMap delegates_from_providers =
+      delegate_providers.CreateAllDelegates();
+  for (auto& name_and_delegate : delegates_from_providers) {
+    delegates.emplace("Delegate_Provider_" + name_and_delegate.first,
+                      std::move(name_and_delegate.second));
+  }
+
   return delegates;
 }
@ -162,21 +229,22 @@ void PrintProfilingInfo(const profiling::ProfileEvent* e,
<< "\n";
}
void RunInference(Settings* s) {
if (!s->model_name.c_str()) {
void RunInference(Settings* settings,
const DelegateProviders& delegate_providers) {
if (!settings->model_name.c_str()) {
LOG(ERROR) << "no model file name\n";
exit(-1);
}
std::unique_ptr<tflite::FlatBufferModel> model;
std::unique_ptr<tflite::Interpreter> interpreter;
model = tflite::FlatBufferModel::BuildFromFile(s->model_name.c_str());
model = tflite::FlatBufferModel::BuildFromFile(settings->model_name.c_str());
if (!model) {
LOG(ERROR) << "\nFailed to mmap model " << s->model_name << "\n";
LOG(ERROR) << "\nFailed to mmap model " << settings->model_name << "\n";
exit(-1);
}
s->model = model.get();
LOG(INFO) << "Loaded model " << s->model_name << "\n";
settings->model = model.get();
LOG(INFO) << "Loaded model " << settings->model_name << "\n";
model->error_reporter();
LOG(INFO) << "resolved reporter\n";
@ -188,9 +256,9 @@ void RunInference(Settings* s) {
     exit(-1);
   }

-  interpreter->SetAllowFp16PrecisionForFp32(s->allow_fp16);
+  interpreter->SetAllowFp16PrecisionForFp32(settings->allow_fp16);

-  if (s->verbose) {
+  if (settings->verbose) {
     LOG(INFO) << "tensors size: " << interpreter->tensors_size() << "\n";
     LOG(INFO) << "nodes size: " << interpreter->nodes_size() << "\n";
     LOG(INFO) << "inputs: " << interpreter->inputs().size() << "\n";
@ -207,28 +275,28 @@ void RunInference(Settings* s) {
     }
   }

-  if (s->number_of_threads != -1) {
-    interpreter->SetNumThreads(s->number_of_threads);
+  if (settings->number_of_threads != -1) {
+    interpreter->SetNumThreads(settings->number_of_threads);
   }

   int image_width = 224;
   int image_height = 224;
   int image_channels = 3;
-  std::vector<uint8_t> in = read_bmp(s->input_bmp_name, &image_width,
-                                     &image_height, &image_channels, s);
+  std::vector<uint8_t> in = read_bmp(settings->input_bmp_name, &image_width,
+                                     &image_height, &image_channels, settings);

   int input = interpreter->inputs()[0];
-  if (s->verbose) LOG(INFO) << "input: " << input << "\n";
+  if (settings->verbose) LOG(INFO) << "input: " << input << "\n";

   const std::vector<int> inputs = interpreter->inputs();
   const std::vector<int> outputs = interpreter->outputs();

-  if (s->verbose) {
+  if (settings->verbose) {
     LOG(INFO) << "number of inputs: " << inputs.size() << "\n";
     LOG(INFO) << "number of outputs: " << outputs.size() << "\n";
   }

-  auto delegates_ = GetDelegates(s);
+  auto delegates_ = GetDelegates(settings, delegate_providers);
   for (const auto& delegate : delegates_) {
     if (interpreter->ModifyGraphWithDelegate(delegate.second.get()) !=
         kTfLiteOk) {
@ -244,7 +312,7 @@ void RunInference(Settings* s) {
     exit(-1);
   }

-  if (s->verbose) PrintInterpreterState(interpreter.get());
+  if (settings->verbose) PrintInterpreterState(interpreter.get());

   // get input dimension from the input tensor metadata
   // assuming one input only
@ -253,44 +321,45 @@ void RunInference(Settings* s) {
   int wanted_width = dims->data[2];
   int wanted_channels = dims->data[3];

-  s->input_type = interpreter->tensor(input)->type;
-  switch (s->input_type) {
+  settings->input_type = interpreter->tensor(input)->type;
+  switch (settings->input_type) {
     case kTfLiteFloat32:
       resize<float>(interpreter->typed_tensor<float>(input), in.data(),
                     image_height, image_width, image_channels, wanted_height,
-                    wanted_width, wanted_channels, s);
+                    wanted_width, wanted_channels, settings);
       break;
     case kTfLiteInt8:
       resize<int8_t>(interpreter->typed_tensor<int8_t>(input), in.data(),
                      image_height, image_width, image_channels, wanted_height,
-                     wanted_width, wanted_channels, s);
+                     wanted_width, wanted_channels, settings);
       break;
     case kTfLiteUInt8:
       resize<uint8_t>(interpreter->typed_tensor<uint8_t>(input), in.data(),
                       image_height, image_width, image_channels, wanted_height,
-                      wanted_width, wanted_channels, s);
+                      wanted_width, wanted_channels, settings);
       break;
     default:
       LOG(ERROR) << "cannot handle input type "
                  << interpreter->tensor(input)->type << " yet\n";
       exit(-1);
   }

-  auto profiler =
-      absl::make_unique<profiling::Profiler>(s->max_profiling_buffer_entries);
+  auto profiler = absl::make_unique<profiling::Profiler>(
+      settings->max_profiling_buffer_entries);
   interpreter->SetProfiler(profiler.get());

-  if (s->profiling) profiler->StartProfiling();
-  if (s->loop_count > 1)
-    for (int i = 0; i < s->number_of_warmup_runs; i++) {
+  if (settings->profiling) profiler->StartProfiling();
+  if (settings->loop_count > 1) {
+    for (int i = 0; i < settings->number_of_warmup_runs; i++) {
       if (interpreter->Invoke() != kTfLiteOk) {
         LOG(ERROR) << "Failed to invoke tflite!\n";
         exit(-1);
       }
     }
+  }

   struct timeval start_time, stop_time;
   gettimeofday(&start_time, nullptr);
-  for (int i = 0; i < s->loop_count; i++) {
+  for (int i = 0; i < settings->loop_count; i++) {
     if (interpreter->Invoke() != kTfLiteOk) {
       LOG(ERROR) << "Failed to invoke tflite!\n";
       exit(-1);
@ -299,10 +368,11 @@ void RunInference(Settings* s) {
   gettimeofday(&stop_time, nullptr);
   LOG(INFO) << "invoked\n";
   LOG(INFO) << "average time: "
-            << (get_us(stop_time) - get_us(start_time)) / (s->loop_count * 1000)
+            << (get_us(stop_time) - get_us(start_time)) /
+                   (settings->loop_count * 1000)
             << " ms \n";

-  if (s->profiling) {
+  if (settings->profiling) {
     profiler->StopProfiling();
     auto profile_events = profiler->GetProfileEvents();
     for (int i = 0; i < profile_events.size(); i++) {
@ -328,18 +398,18 @@ void RunInference(Settings* s) {
   switch (interpreter->tensor(output)->type) {
     case kTfLiteFloat32:
       get_top_n<float>(interpreter->typed_output_tensor<float>(0), output_size,
-                       s->number_of_results, threshold, &top_results,
-                       s->input_type);
+                       settings->number_of_results, threshold, &top_results,
+                       settings->input_type);
       break;
     case kTfLiteInt8:
       get_top_n<int8_t>(interpreter->typed_output_tensor<int8_t>(0),
-                        output_size, s->number_of_results, threshold,
-                        &top_results, s->input_type);
+                        output_size, settings->number_of_results, threshold,
+                        &top_results, settings->input_type);
       break;
     case kTfLiteUInt8:
       get_top_n<uint8_t>(interpreter->typed_output_tensor<uint8_t>(0),
-                         output_size, s->number_of_results, threshold,
-                         &top_results, s->input_type);
+                         output_size, settings->number_of_results, threshold,
+                         &top_results, settings->input_type);
       break;
     default:
       LOG(ERROR) << "cannot handle output type "
@ -350,7 +420,8 @@ void RunInference(Settings* s) {
   std::vector<string> labels;
   size_t label_count;

-  if (ReadLabelsFile(s->labels_file_name, &labels, &label_count) != kTfLiteOk)
+  if (ReadLabelsFile(settings->labels_file_name, &labels, &label_count) !=
+      kTfLiteOk)
     exit(-1);

   for (const auto& result : top_results) {
@ -383,6 +454,13 @@ void display_usage() {
 }

 int Main(int argc, char** argv) {
+  DelegateProviders delegate_providers;
+  bool parse_result = delegate_providers.InitFromCmdlineArgs(
+      &argc, const_cast<const char**>(argv));
+  if (!parse_result) {
+    return EXIT_FAILURE;
+  }
+
   Settings s;
   int c;
@ -488,7 +566,7 @@ int Main(int argc, char** argv) {
         exit(-1);
     }
   }

-  RunInference(&s);
+  RunInference(&s, delegate_providers);
   return 0;
 }