Support the delegate registrar (https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/tools/delegates) in the label_image example so that the various delegate options supported by the TFLite benchmark and task evaluation tools can be reused here.
PiperOrigin-RevId: 334428572
Change-Id: Ib942b733ee4b0fe4a1a410bf38d9db1173dc00c9
parent 7ed2ba3fa4
commit 905e0be75a
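For readers unfamiliar with the registrar this commit adopts: each delegate under `tensorflow/lite/tools/delegates` implements the `DelegateProvider` interface and registers itself statically, which is what lets this binary pick up the same delegate flags as the benchmark tool. The sketch below is illustrative only and is not part of this commit; `MyDelegateProvider` and its flag are hypothetical, and the method signatures are assumed to match `delegate_provider.h` at this revision.

```
// Illustrative only: a hypothetical provider showing the registration
// contract assumed by this commit. Real providers (XNNPACK, GPU, NNAPI, ...)
// live in tensorflow/lite/tools/delegates/.
#include "tensorflow/lite/tools/delegates/delegate_provider.h"

namespace tflite {
namespace tools {

class MyDelegateProvider : public DelegateProvider {
 public:
  MyDelegateProvider() {
    // Seed the default value for the single flag this provider contributes.
    default_params_.AddParam("use_my_delegate",
                             ToolParam::Create<bool>(false));
  }

  std::vector<Flag> CreateFlags(ToolParams* params) const final {
    return {CreateFlag<bool>("use_my_delegate", params,
                             "apply the (hypothetical) my_delegate")};
  }

  void LogParams(const ToolParams& params) const final {
    // A real provider would log the effective flag values here.
  }

  TfLiteDelegatePtr CreateTfLiteDelegate(const ToolParams& params) const final {
    // A real provider would construct its delegate here when the flag is on;
    // returning a null delegate means "nothing to add".
    return TfLiteDelegatePtr(nullptr, [](TfLiteDelegate*) {});
  }

  std::string GetName() const final { return "MY_DELEGATE"; }
};
REGISTER_DELEGATE_PROVIDER(MyDelegateProvider);

}  // namespace tools
}  // namespace tflite
```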
--- a/tensorflow/lite/examples/label_image/BUILD
+++ b/tensorflow/lite/examples/label_image/BUILD
@@ -29,16 +29,20 @@ cc_binary(
     }),
     deps = [
         ":bitmap_helpers",
-        "//tensorflow/lite/c:common",
-        "@com_google_absl//absl/memory",
-        "@com_google_absl//absl/strings",
         "//tensorflow/lite:framework",
         "//tensorflow/lite:string_util",
+        "//tensorflow/lite/c:common",
         "//tensorflow/lite/delegates/nnapi:nnapi_delegate",
         "//tensorflow/lite/delegates/xnnpack:xnnpack_delegate",
         "//tensorflow/lite/kernels:builtin_ops",
         "//tensorflow/lite/profiling:profiler",
+        "//tensorflow/lite/tools:command_line_flags",
+        "//tensorflow/lite/tools:tool_params",
+        "//tensorflow/lite/tools/delegates:delegate_provider_hdr",
+        "//tensorflow/lite/tools/delegates:tflite_execution_providers",
         "//tensorflow/lite/tools/evaluation:utils",
+        "@com_google_absl//absl/memory",
+        "@com_google_absl//absl/strings",
     ] + select({
         "//tensorflow:android": [
             "//tensorflow/lite/delegates/gpu:delegate",
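On the two new delegate targets: `delegate_provider_hdr` supplies only the `DelegateProvider` interface, while `tflite_execution_providers` is the convenience target that links the concrete providers, whose static registration populates the list the new code iterates over. A hypothetical probe, assuming those targets are linked into the binary:

```
// Hypothetical probe (not part of this commit): with
// //tensorflow/lite/tools/delegates:tflite_execution_providers linked in,
// the registrar can enumerate every provider compiled into the binary.
#include <iostream>

#include "tensorflow/lite/tools/delegates/delegate_provider.h"

int main() {
  for (const auto& provider :
       tflite::tools::GetRegisteredDelegateProviders()) {
    std::cout << provider->GetName() << "\n";  // e.g. XNNPACK, GPU, NNAPI
  }
  return 0;
}
```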
--- a/tensorflow/lite/examples/label_image/README.md
+++ b/tensorflow/lite/examples/label_image/README.md
@@ -222,3 +222,20 @@ label_image
 ```
 
 See the `label_image.cc` source code for other command line options.
+
+Note that this binary also supports runtime/delegate arguments introduced by the
+[delegate registrar](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/tools/delegates).
+If there is any conflict, the arguments mentioned earlier take precedence. For
+example, you can run the binary with additional command line options such as
+`--use_nnapi=true --nnapi_accelerator_name=google-edgetpu` to utilize the
+EdgeTPU on a 4th-gen Pixel phone. Note that the "=" in these options must not
+be omitted.
+
+```
+adb shell \
+  "/data/local/tmp/label_image \
+    -m /data/local/tmp/mobilenet_v1_1.0_224_quant.tflite \
+    -i /data/local/tmp/grace_hopper.bmp \
+    -l /data/local/tmp/labels.txt -j 1 \
+    --use_nnapi=true --nnapi_accelerator_name=google-edgetpu"
+```
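Why the "=" matters: `tflite::Flags` (unlike getopt) only matches single-token `--name=value` arguments, so the two-token form `--use_nnapi true` is left unparsed. A minimal sketch of that parsing, assuming the `command_line_flags.h` API at this revision; `ParseUseNnapi` is a made-up helper:

```
// Minimal sketch, not part of this commit: parse one delegate-style flag the
// way the providers do. Only the "--use_nnapi=true" form is recognized.
#include <vector>

#include "tensorflow/lite/tools/command_line_flags.h"

bool ParseUseNnapi(int* argc, const char** argv, bool* use_nnapi) {
  std::vector<tflite::Flag> flags = {
      tflite::Flag::CreateFlag("use_nnapi", use_nnapi,
                               "use the NNAPI delegate"),
  };
  // Matching arguments are consumed and removed from (*argc, argv).
  return tflite::Flags::Parse(argc, argv, flags);
}
```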
--- a/tensorflow/lite/examples/label_image/label_image.cc
+++ b/tensorflow/lite/examples/label_image/label_image.cc
@@ -44,6 +44,8 @@ limitations under the License.
 #include "tensorflow/lite/optional_debug_tools.h"
 #include "tensorflow/lite/profiling/profiler.h"
 #include "tensorflow/lite/string_util.h"
+#include "tensorflow/lite/tools/command_line_flags.h"
+#include "tensorflow/lite/tools/delegates/delegate_provider.h"
 #include "tensorflow/lite/tools/evaluation/utils.h"
 
 #if defined(__ANDROID__)
@@ -60,6 +62,56 @@ double get_us(struct timeval t) { return (t.tv_sec * 1000000 + t.tv_usec); }
 using TfLiteDelegatePtr = tflite::Interpreter::TfLiteDelegatePtr;
 using TfLiteDelegatePtrMap = std::map<std::string, TfLiteDelegatePtr>;
 
+class DelegateProviders {
+ public:
+  DelegateProviders()
+      : delegates_list_(tflite::tools::GetRegisteredDelegateProviders()) {
+    for (const auto& delegate : delegates_list_) {
+      params_.Merge(delegate->DefaultParams());
+    }
+  }
+
+  // Initializes delegate-related parameters by parsing command line arguments,
+  // and removes the matching arguments from (*argc, argv). Returns true if all
+  // recognized argument values are parsed correctly.
+  bool InitFromCmdlineArgs(int* argc, const char** argv) {
+    std::vector<tflite::Flag> flags;
+    for (const auto& delegate : delegates_list_) {
+      auto delegate_flags = delegate->CreateFlags(&params_);
+      flags.insert(flags.end(), delegate_flags.begin(), delegate_flags.end());
+    }
+
+    const bool parse_result = Flags::Parse(argc, argv, flags);
+    if (!parse_result) {
+      std::string usage = Flags::Usage(argv[0], flags);
+      LOG(ERROR) << usage;
+    }
+    return parse_result;
+  }
+
+  // Creates a list of TfLite delegates based on what has been initialized
+  // (i.e. 'params_').
+  TfLiteDelegatePtrMap CreateAllDelegates() const {
+    TfLiteDelegatePtrMap delegates_map;
+    for (const auto& delegate : delegates_list_) {
+      auto ptr = delegate->CreateTfLiteDelegate(params_);
+      // It's possible that a delegate of a certain type won't be created
+      // because the user-specified params tell it not to.
+      if (ptr == nullptr) continue;
+      LOG(INFO) << delegate->GetName() << " delegate created.\n";
+      delegates_map.emplace(delegate->GetName(), std::move(ptr));
+    }
+    return delegates_map;
+  }
+
+ private:
+  // Contains delegate-related parameters that are initialized from
+  // command-line flags.
+  tflite::tools::ToolParams params_;
+
+  const tflite::tools::DelegateProviderList& delegates_list_;
+};
+
 TfLiteDelegatePtr CreateGPUDelegate(Settings* s) {
 #if defined(__ANDROID__)
   TfLiteGpuDelegateOptionsV2 gpu_opts = TfLiteGpuDelegateOptionsV2Default();
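A condensed view of how this helper is meant to be driven (`Main()` further down does exactly this); `RunLabelImage` here is a hypothetical wrapper, not code from this commit:

```
// Sketch only: provider flags are parsed and stripped from argv before the
// getopt-based options, so both flag styles share one command line.
int RunLabelImage(int argc, char** argv) {
  DelegateProviders delegate_providers;
  if (!delegate_providers.InitFromCmdlineArgs(
          &argc, const_cast<const char**>(argv))) {
    return EXIT_FAILURE;  // a registered delegate flag had an unparsable value
  }
  // e.g. {"label_image", "-m", "m.tflite", "--use_xnnpack=true"} has been
  // reduced to {"label_image", "-m", "m.tflite"} by this point, with the
  // XNNPACK choice recorded in the providers' ToolParams.
  // ... getopt-based Settings parsing and RunInference(...) as below ...
  return 0;
}
```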
@@ -74,7 +126,10 @@ TfLiteDelegatePtr CreateGPUDelegate(Settings* s) {
 #endif
 }
 
-TfLiteDelegatePtrMap GetDelegates(Settings* s) {
+TfLiteDelegatePtrMap GetDelegates(Settings* s,
+                                  const DelegateProviders& delegate_providers) {
+  // TODO(b/169681115): deprecate the "Settings"-based delegate creation path
+  // by mapping settings to DelegateProvider's parameters.
   TfLiteDelegatePtrMap delegates;
   if (s->gl_backend) {
     auto delegate = CreateGPUDelegate(s);
@@ -117,6 +172,18 @@ TfLiteDelegatePtrMap GetDelegates(Settings* s) {
     }
   }
 
+  // Independent of the above delegate creation options that are specific to
+  // this binary, we use delegate providers to create TFLite delegates.
+  // Delegate providers have been used in TFLite benchmark/evaluation tools
+  // and testing so that we have a single, more comprehensive set of command
+  // line arguments for delegate creation.
+  TfLiteDelegatePtrMap delegates_from_providers =
+      delegate_providers.CreateAllDelegates();
+  for (auto& name_and_delegate : delegates_from_providers) {
+    delegates.emplace("Delegate_Provider_" + name_and_delegate.first,
+                      std::move(name_and_delegate.second));
+  }
+
   return delegates;
 }
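A note on the `"Delegate_Provider_"` prefix: the Settings-based path above presumably already inserts keys such as `"NNAPI"` or `"XNNPACK"`, so prefixing the provider-created entries keeps the two sources from colliding in the map (a delegate requested both ways would appear as, say, `"NNAPI"` and `"Delegate_Provider_NNAPI"`). As far as one can tell from this file, the keys are only used in the log message emitted when `ModifyGraphWithDelegate` fails.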
@@ -162,21 +229,22 @@ void PrintProfilingInfo(const profiling::ProfileEvent* e,
             << "\n";
 }
 
-void RunInference(Settings* s) {
-  if (!s->model_name.c_str()) {
+void RunInference(Settings* settings,
+                  const DelegateProviders& delegate_providers) {
+  if (!settings->model_name.c_str()) {
     LOG(ERROR) << "no model file name\n";
     exit(-1);
   }
 
   std::unique_ptr<tflite::FlatBufferModel> model;
   std::unique_ptr<tflite::Interpreter> interpreter;
-  model = tflite::FlatBufferModel::BuildFromFile(s->model_name.c_str());
+  model = tflite::FlatBufferModel::BuildFromFile(settings->model_name.c_str());
   if (!model) {
-    LOG(ERROR) << "\nFailed to mmap model " << s->model_name << "\n";
+    LOG(ERROR) << "\nFailed to mmap model " << settings->model_name << "\n";
     exit(-1);
   }
-  s->model = model.get();
-  LOG(INFO) << "Loaded model " << s->model_name << "\n";
+  settings->model = model.get();
+  LOG(INFO) << "Loaded model " << settings->model_name << "\n";
   model->error_reporter();
   LOG(INFO) << "resolved reporter\n";
 
@@ -188,9 +256,9 @@ void RunInference(Settings* s) {
     exit(-1);
   }
 
-  interpreter->SetAllowFp16PrecisionForFp32(s->allow_fp16);
+  interpreter->SetAllowFp16PrecisionForFp32(settings->allow_fp16);
 
-  if (s->verbose) {
+  if (settings->verbose) {
     LOG(INFO) << "tensors size: " << interpreter->tensors_size() << "\n";
     LOG(INFO) << "nodes size: " << interpreter->nodes_size() << "\n";
     LOG(INFO) << "inputs: " << interpreter->inputs().size() << "\n";
@@ -207,28 +275,28 @@ void RunInference(Settings* s) {
     }
   }
 
-  if (s->number_of_threads != -1) {
-    interpreter->SetNumThreads(s->number_of_threads);
+  if (settings->number_of_threads != -1) {
+    interpreter->SetNumThreads(settings->number_of_threads);
   }
 
   int image_width = 224;
   int image_height = 224;
   int image_channels = 3;
-  std::vector<uint8_t> in = read_bmp(s->input_bmp_name, &image_width,
-                                     &image_height, &image_channels, s);
+  std::vector<uint8_t> in = read_bmp(settings->input_bmp_name, &image_width,
+                                     &image_height, &image_channels, settings);
 
   int input = interpreter->inputs()[0];
-  if (s->verbose) LOG(INFO) << "input: " << input << "\n";
+  if (settings->verbose) LOG(INFO) << "input: " << input << "\n";
 
   const std::vector<int> inputs = interpreter->inputs();
   const std::vector<int> outputs = interpreter->outputs();
 
-  if (s->verbose) {
+  if (settings->verbose) {
     LOG(INFO) << "number of inputs: " << inputs.size() << "\n";
     LOG(INFO) << "number of outputs: " << outputs.size() << "\n";
   }
 
-  auto delegates_ = GetDelegates(s);
+  auto delegates_ = GetDelegates(settings, delegate_providers);
   for (const auto& delegate : delegates_) {
     if (interpreter->ModifyGraphWithDelegate(delegate.second.get()) !=
         kTfLiteOk) {
@@ -244,7 +312,7 @@ void RunInference(Settings* s) {
     exit(-1);
   }
 
-  if (s->verbose) PrintInterpreterState(interpreter.get());
+  if (settings->verbose) PrintInterpreterState(interpreter.get());
 
   // get input dimension from the input tensor metadata
   // assuming one input only
@@ -253,44 +321,45 @@ void RunInference(Settings* s) {
   int wanted_width = dims->data[2];
   int wanted_channels = dims->data[3];
 
-  s->input_type = interpreter->tensor(input)->type;
-  switch (s->input_type) {
+  settings->input_type = interpreter->tensor(input)->type;
+  switch (settings->input_type) {
     case kTfLiteFloat32:
       resize<float>(interpreter->typed_tensor<float>(input), in.data(),
                     image_height, image_width, image_channels, wanted_height,
-                    wanted_width, wanted_channels, s);
+                    wanted_width, wanted_channels, settings);
       break;
     case kTfLiteInt8:
       resize<int8_t>(interpreter->typed_tensor<int8_t>(input), in.data(),
                      image_height, image_width, image_channels, wanted_height,
-                     wanted_width, wanted_channels, s);
+                     wanted_width, wanted_channels, settings);
       break;
     case kTfLiteUInt8:
       resize<uint8_t>(interpreter->typed_tensor<uint8_t>(input), in.data(),
                       image_height, image_width, image_channels, wanted_height,
-                      wanted_width, wanted_channels, s);
+                      wanted_width, wanted_channels, settings);
       break;
    default:
      LOG(ERROR) << "cannot handle input type "
                 << interpreter->tensor(input)->type << " yet\n";
      exit(-1);
  }
-  auto profiler =
-      absl::make_unique<profiling::Profiler>(s->max_profiling_buffer_entries);
+  auto profiler = absl::make_unique<profiling::Profiler>(
+      settings->max_profiling_buffer_entries);
   interpreter->SetProfiler(profiler.get());
 
-  if (s->profiling) profiler->StartProfiling();
-  if (s->loop_count > 1)
-    for (int i = 0; i < s->number_of_warmup_runs; i++) {
+  if (settings->profiling) profiler->StartProfiling();
+  if (settings->loop_count > 1) {
+    for (int i = 0; i < settings->number_of_warmup_runs; i++) {
       if (interpreter->Invoke() != kTfLiteOk) {
         LOG(ERROR) << "Failed to invoke tflite!\n";
         exit(-1);
       }
     }
+  }
 
   struct timeval start_time, stop_time;
   gettimeofday(&start_time, nullptr);
-  for (int i = 0; i < s->loop_count; i++) {
+  for (int i = 0; i < settings->loop_count; i++) {
     if (interpreter->Invoke() != kTfLiteOk) {
       LOG(ERROR) << "Failed to invoke tflite!\n";
       exit(-1);
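Worth noting in the hunk above: wrapping the warm-up loop in braces is purely cosmetic. The `for` statement was already the sole body of the `if (s->loop_count > 1)`, so behavior is unchanged; the braces just make the scope explicit alongside the `s` → `settings` rename.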
@@ -299,10 +368,11 @@ void RunInference(Settings* s) {
   gettimeofday(&stop_time, nullptr);
   LOG(INFO) << "invoked\n";
   LOG(INFO) << "average time: "
-            << (get_us(stop_time) - get_us(start_time)) / (s->loop_count * 1000)
+            << (get_us(stop_time) - get_us(start_time)) /
+                   (settings->loop_count * 1000)
             << " ms \n";
 
-  if (s->profiling) {
+  if (settings->profiling) {
     profiler->StopProfiling();
     auto profile_events = profiler->GetProfileEvents();
     for (int i = 0; i < profile_events.size(); i++) {
@@ -328,18 +398,18 @@ void RunInference(Settings* s) {
   switch (interpreter->tensor(output)->type) {
     case kTfLiteFloat32:
       get_top_n<float>(interpreter->typed_output_tensor<float>(0), output_size,
-                       s->number_of_results, threshold, &top_results,
-                       s->input_type);
+                       settings->number_of_results, threshold, &top_results,
+                       settings->input_type);
       break;
     case kTfLiteInt8:
       get_top_n<int8_t>(interpreter->typed_output_tensor<int8_t>(0),
-                        output_size, s->number_of_results, threshold,
-                        &top_results, s->input_type);
+                        output_size, settings->number_of_results, threshold,
+                        &top_results, settings->input_type);
       break;
     case kTfLiteUInt8:
       get_top_n<uint8_t>(interpreter->typed_output_tensor<uint8_t>(0),
-                         output_size, s->number_of_results, threshold,
-                         &top_results, s->input_type);
+                         output_size, settings->number_of_results, threshold,
+                         &top_results, settings->input_type);
       break;
     default:
       LOG(ERROR) << "cannot handle output type "
@@ -350,7 +420,8 @@ void RunInference(Settings* s) {
   std::vector<string> labels;
   size_t label_count;
 
-  if (ReadLabelsFile(s->labels_file_name, &labels, &label_count) != kTfLiteOk)
+  if (ReadLabelsFile(settings->labels_file_name, &labels, &label_count) !=
+      kTfLiteOk)
     exit(-1);
 
   for (const auto& result : top_results) {
@@ -383,6 +454,13 @@ void display_usage() {
 }
 
 int Main(int argc, char** argv) {
+  DelegateProviders delegate_providers;
+  bool parse_result = delegate_providers.InitFromCmdlineArgs(
+      &argc, const_cast<const char**>(argv));
+  if (!parse_result) {
+    return EXIT_FAILURE;
+  }
+
   Settings s;
 
   int c;
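On the early-exit path above: when `InitFromCmdlineArgs` fails, it has already logged the generated usage string via `LOG(ERROR)` (see the class earlier in this diff), so `Main` can simply return `EXIT_FAILURE`. Since this runs before the getopt loop, a malformed provider flag (say, a hypothetical `--use_xnnpack=banana`) should abort the run before any of the short options are inspected.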
@@ -488,7 +566,7 @@ int Main(int argc, char** argv) {
         exit(-1);
     }
   }
-  RunInference(&s);
+  RunInference(&s, delegate_providers);
   return 0;
 }