For tflite evaluation tools, support outputting unconsumed cmdline flags and exiting the execution if flags fail to be parsed.

PiperOrigin-RevId: 317989024
Change-Id: I52cc2249246b7d19c9c8a257ac1478d48f7de8fa
Chao Mei authored on 2020-06-23 20:04:29 -07:00; committed by TensorFlower Gardener
parent ef20289d99
commit a4f7dd5436
11 changed files with 104 additions and 41 deletions

@@ -97,7 +97,13 @@ bool DelegateProviders::InitFromCmdlineArgs(int* argc, const char** argv) {
auto one_flags = one->CreateFlags(&params_);
flags.insert(flags.end(), one_flags.begin(), one_flags.end());
}
return Flags::Parse(argc, argv, flags);
const bool parse_result = Flags::Parse(argc, argv, flags);
if (!parse_result) {
std::string usage = Flags::Usage(argv[0], flags);
TFLITE_LOG(ERROR) << usage;
}
return parse_result;
}
TfLiteDelegatePtr DelegateProviders::CreateDelegate(
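Not part of the diff, but for illustration: with this change a caller can rely on the boolean returned by InitFromCmdlineArgs, since the usage string is now logged inside the function on a parse failure. A minimal sketch, assuming a main-style argc/argv:

    #include <cstdlib>

    #include "tensorflow/lite/tools/evaluation/evaluation_delegate_provider.h"

    int main(int argc, char* argv[]) {
      tflite::evaluation::DelegateProviders providers;
      // On failure, InitFromCmdlineArgs has already logged the usage string,
      // so the caller only needs to stop the execution.
      if (!providers.InitFromCmdlineArgs(&argc, const_cast<const char**>(argv))) {
        return EXIT_FAILURE;
      }
      // ... proceed to create delegates and run the evaluation ...
      return EXIT_SUCCESS;
    }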

@@ -10,10 +10,14 @@ package(
cc_library(
name = "task_executor",
srcs = ["task_executor.cc"],
hdrs = ["task_executor.h"],
copts = tflite_copts(),
linkopts = task_linkopts(),
deps = [
"//tensorflow/lite/tools:command_line_flags",
"//tensorflow/lite/tools:logging",
"//tensorflow/lite/tools/evaluation:evaluation_delegate_provider",
"//tensorflow/lite/tools/evaluation/proto:evaluation_config_cc_proto",
"@com_google_absl//absl/types:optional",
],

@@ -26,7 +26,6 @@ cc_library(
"//tensorflow/lite/c:common",
"//tensorflow/lite/tools:command_line_flags",
"//tensorflow/lite/tools:logging",
"//tensorflow/lite/tools/evaluation:evaluation_delegate_provider",
"//tensorflow/lite/tools/evaluation:evaluation_stage",
"//tensorflow/lite/tools/evaluation:utils",
"//tensorflow/lite/tools/evaluation/proto:evaluation_config_cc_proto",

@@ -21,7 +21,6 @@ limitations under the License.
#include "absl/types/optional.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/tools/command_line_flags.h"
#include "tensorflow/lite/tools/evaluation/evaluation_delegate_provider.h"
#include "tensorflow/lite/tools/evaluation/proto/evaluation_config.pb.h"
#include "tensorflow/lite/tools/evaluation/proto/evaluation_stages.pb.h"
#include "tensorflow/lite/tools/evaluation/stages/object_detection_stage.h"
@@ -49,11 +48,14 @@ std::string GetNameFromPath(const std::string& str) {
class CocoObjectDetection : public TaskExecutor {
public:
CocoObjectDetection(int* argc, char* argv[]);
CocoObjectDetection() : debug_mode_(false), num_interpreter_threads_(1) {}
~CocoObjectDetection() override {}
protected:
std::vector<Flag> GetFlags() final;
// If the run is successful, the latest metrics will be returned.
absl::optional<EvaluationStageMetrics> Run() final;
absl::optional<EvaluationStageMetrics> RunImpl() final;
private:
void OutputResult(const EvaluationStageMetrics& latest_metrics) const;
@@ -68,8 +70,7 @@ class CocoObjectDetection : public TaskExecutor {
DelegateProviders delegate_providers_;
};
CocoObjectDetection::CocoObjectDetection(int* argc, char* argv[])
: debug_mode_(false), num_interpreter_threads_(1) {
std::vector<Flag> CocoObjectDetection::GetFlags() {
std::vector<tflite::Flag> flag_list = {
tflite::Flag::CreateFlag(kModelFileFlag, &model_file_path_,
"Path to test tflite model file."),
@@ -105,12 +106,10 @@ CocoObjectDetection::CocoObjectDetection(int* argc, char* argv[])
"Delegate to use for inference, if available. "
"Must be one of {'nnapi', 'gpu', 'xnnpack', 'hexagon'}"),
};
tflite::Flags::Parse(argc, const_cast<const char**>(argv), flag_list);
DelegateProviders delegate_providers;
delegate_providers.InitFromCmdlineArgs(argc, const_cast<const char**>(argv));
return flag_list;
}
absl::optional<EvaluationStageMetrics> CocoObjectDetection::Run() {
absl::optional<EvaluationStageMetrics> CocoObjectDetection::RunImpl() {
// Process images in filename-sorted order.
std::vector<std::string> image_paths;
if (GetSortedFileNames(StripTrailingSlashes(ground_truth_images_path_),
@@ -224,8 +223,8 @@ void CocoObjectDetection::OutputResult(
<< precision_metrics.overall_mean_average_precision();
}
std::unique_ptr<TaskExecutor> CreateTaskExecutor(int* argc, char* argv[]) {
return std::unique_ptr<TaskExecutor>(new CocoObjectDetection(argc, argv));
std::unique_ptr<TaskExecutor> CreateTaskExecutor() {
return std::unique_ptr<TaskExecutor>(new CocoObjectDetection());
}
} // namespace evaluation

@@ -17,7 +17,6 @@ cc_library(
"//tensorflow/lite/c:common",
"//tensorflow/lite/tools:command_line_flags",
"//tensorflow/lite/tools:logging",
"//tensorflow/lite/tools/evaluation:evaluation_delegate_provider",
"//tensorflow/lite/tools/evaluation:evaluation_stage",
"//tensorflow/lite/tools/evaluation:utils",
"//tensorflow/lite/tools/evaluation/proto:evaluation_config_cc_proto",

@@ -20,7 +20,6 @@ limitations under the License.
#include "absl/types/optional.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/tools/command_line_flags.h"
#include "tensorflow/lite/tools/evaluation/evaluation_delegate_provider.h"
#include "tensorflow/lite/tools/evaluation/proto/evaluation_config.pb.h"
#include "tensorflow/lite/tools/evaluation/proto/evaluation_stages.pb.h"
#include "tensorflow/lite/tools/evaluation/stages/image_classification_stage.h"
@@ -50,11 +49,14 @@ std::vector<T> GetFirstN(const std::vector<T>& v, int n) {
class ImagenetClassification : public TaskExecutor {
public:
ImagenetClassification(int* argc, char* argv[]);
ImagenetClassification() : num_images_(0), num_interpreter_threads_(1) {}
~ImagenetClassification() override {}
protected:
std::vector<Flag> GetFlags() final;
// If the run is successful, the latest metrics will be returned.
absl::optional<EvaluationStageMetrics> Run() final;
absl::optional<EvaluationStageMetrics> RunImpl() final;
private:
void OutputResult(const EvaluationStageMetrics& latest_metrics) const;
@@ -67,11 +69,9 @@ class ImagenetClassification : public TaskExecutor {
std::string delegate_;
int num_images_;
int num_interpreter_threads_;
DelegateProviders delegate_providers_;
};
ImagenetClassification::ImagenetClassification(int* argc, char* argv[])
: num_images_(0), num_interpreter_threads_(1) {
std::vector<Flag> ImagenetClassification::GetFlags() {
std::vector<tflite::Flag> flag_list = {
tflite::Flag::CreateFlag(kModelFileFlag, &model_file_path_,
"Path to test tflite model file."),
@@ -107,11 +107,10 @@ ImagenetClassification::ImagenetClassification(int* argc, char* argv[])
"Delegate to use for inference, if available. "
"Must be one of {'nnapi', 'gpu', 'hexagon', 'xnnpack'}"),
};
tflite::Flags::Parse(argc, const_cast<const char**>(argv), flag_list);
delegate_providers_.InitFromCmdlineArgs(argc, const_cast<const char**>(argv));
return flag_list;
}
absl::optional<EvaluationStageMetrics> ImagenetClassification::Run() {
absl::optional<EvaluationStageMetrics> ImagenetClassification::RunImpl() {
// Process images in filename-sorted order.
std::vector<std::string> image_files, ground_truth_image_labels;
if (GetSortedFileNames(StripTrailingSlashes(ground_truth_images_path_),
@@ -203,8 +202,8 @@ void ImagenetClassification::OutputResult(
}
}
std::unique_ptr<TaskExecutor> CreateTaskExecutor(int* argc, char* argv[]) {
return std::unique_ptr<TaskExecutor>(new ImagenetClassification(argc, argv));
std::unique_ptr<TaskExecutor> CreateTaskExecutor() {
return std::unique_ptr<TaskExecutor>(new ImagenetClassification());
}
} // namespace evaluation

@@ -17,7 +17,6 @@ cc_library(
"//tensorflow/lite/c:common",
"//tensorflow/lite/tools:command_line_flags",
"//tensorflow/lite/tools:logging",
"//tensorflow/lite/tools/evaluation:evaluation_delegate_provider",
"//tensorflow/lite/tools/evaluation:evaluation_stage",
"//tensorflow/lite/tools/evaluation/proto:evaluation_config_cc_proto",
"//tensorflow/lite/tools/evaluation/proto:evaluation_stages_cc_proto",

@@ -19,7 +19,6 @@ limitations under the License.
#include "absl/types/optional.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/tools/command_line_flags.h"
#include "tensorflow/lite/tools/evaluation/evaluation_delegate_provider.h"
#include "tensorflow/lite/tools/evaluation/proto/evaluation_config.pb.h"
#include "tensorflow/lite/tools/evaluation/proto/evaluation_stages.pb.h"
#include "tensorflow/lite/tools/evaluation/stages/inference_profiler_stage.h"
@@ -37,11 +36,14 @@ constexpr char kDelegateFlag[] = "delegate";
class InferenceDiff : public TaskExecutor {
public:
InferenceDiff(int* argc, char* argv[]);
InferenceDiff() : num_runs_(50), num_interpreter_threads_(1) {}
~InferenceDiff() override {}
protected:
std::vector<Flag> GetFlags() final;
// If the run is successful, the latest metrics will be returned.
absl::optional<EvaluationStageMetrics> Run() final;
absl::optional<EvaluationStageMetrics> RunImpl() final;
private:
void OutputResult(const EvaluationStageMetrics& latest_metrics) const;
@@ -50,11 +52,9 @@ class InferenceDiff : public TaskExecutor {
std::string delegate_;
int num_runs_;
int num_interpreter_threads_;
DelegateProviders delegate_providers_;
};
InferenceDiff::InferenceDiff(int* argc, char* argv[])
: num_runs_(50), num_interpreter_threads_(1) {
std::vector<Flag> InferenceDiff::GetFlags() {
// Command Line Flags.
std::vector<tflite::Flag> flag_list = {
tflite::Flag::CreateFlag(kModelFileFlag, &model_file_path_,
@@ -72,11 +72,11 @@ InferenceDiff::InferenceDiff(int* argc, char* argv[])
"Delegate to use for test inference, if available. "
"Must be one of {'nnapi', 'gpu', 'hexagon', 'xnnpack'}"),
};
tflite::Flags::Parse(argc, const_cast<const char**>(argv), flag_list);
delegate_providers_.InitFromCmdlineArgs(argc, const_cast<const char**>(argv));
return flag_list;
}
absl::optional<EvaluationStageMetrics> InferenceDiff::Run() {
absl::optional<EvaluationStageMetrics> InferenceDiff::RunImpl() {
// Initialize evaluation stage.
EvaluationStageConfig eval_config;
eval_config.set_name("inference_profiling");
@@ -137,8 +137,8 @@ void InferenceDiff::OutputResult(
}
}
std::unique_ptr<TaskExecutor> CreateTaskExecutor(int* argc, char* argv[]) {
return std::unique_ptr<TaskExecutor>(new InferenceDiff(argc, argv));
std::unique_ptr<TaskExecutor> CreateTaskExecutor() {
return std::unique_ptr<TaskExecutor>(new InferenceDiff());
}
} // namespace evaluation

@@ -0,0 +1,47 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/tools/evaluation/tasks/task_executor.h"
#include "absl/types/optional.h"
#include "tensorflow/lite/tools/logging.h"
namespace tflite {
namespace evaluation {
absl::optional<EvaluationStageMetrics> TaskExecutor::Run(int* argc,
char* argv[]) {
auto flag_list = GetFlags();
bool parse_result =
tflite::Flags::Parse(argc, const_cast<const char**>(argv), flag_list);
if (!parse_result) {
std::string usage = Flags::Usage(argv[0], flag_list);
TFLITE_LOG(ERROR) << usage;
return absl::nullopt;
}
parse_result = delegate_providers_.InitFromCmdlineArgs(
argc, const_cast<const char**>(argv));
if (!parse_result) {
return absl::nullopt;
}
std::string unconsumed_args =
Flags::ArgsToString(*argc, const_cast<const char**>(argv));
if (!unconsumed_args.empty()) {
TFLITE_LOG(WARN) << "Unconsumed cmdline flags: " << unconsumed_args;
}
return RunImpl();
}
} // namespace evaluation
} // namespace tflite
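A side note on why the leftover-args check in Run works (a sketch only, not from this commit, and SketchLeftoverCheck is a made-up helper name): Flags::Parse strips the flags it recognizes out of argv and updates argc, so whatever Flags::ArgsToString reports afterwards was consumed by neither the task flags nor the delegate providers. Roughly:

    #include <string>
    #include <vector>

    #include "tensorflow/lite/tools/command_line_flags.h"
    #include "tensorflow/lite/tools/logging.h"

    // Parse one flag, then report anything the parser left behind.
    void SketchLeftoverCheck(int* argc, char* argv[]) {
      int num_runs = 50;
      std::vector<tflite::Flag> flags = {
          tflite::Flag::CreateFlag("num_runs", &num_runs, "Number of runs.")};
      if (!tflite::Flags::Parse(argc, const_cast<const char**>(argv), flags)) {
        TFLITE_LOG(ERROR) << tflite::Flags::Usage(argv[0], flags);
        return;
      }
      // Recognized flags have been removed from argv, so anything reported
      // here was not consumed by the parse above.
      const std::string leftovers =
          tflite::Flags::ArgsToString(*argc, const_cast<const char**>(argv));
      if (!leftovers.empty()) {
        TFLITE_LOG(WARN) << "Unconsumed cmdline flags: " << leftovers;
      }
    }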

@@ -16,6 +16,8 @@ limitations under the License.
#define TENSORFLOW_LITE_TOOLS_EVALUATION_TASKS_TASK_EXECUTOR_H_
#include "absl/types/optional.h"
#include "tensorflow/lite/tools/command_line_flags.h"
#include "tensorflow/lite/tools/evaluation/evaluation_delegate_provider.h"
#include "tensorflow/lite/tools/evaluation/proto/evaluation_config.pb.h"
namespace tflite {
@@ -25,13 +27,22 @@ namespace evaluation {
class TaskExecutor {
public:
virtual ~TaskExecutor() {}
// If the run is successful, the latest metrics will be returned.
virtual absl::optional<EvaluationStageMetrics> Run() = 0;
absl::optional<EvaluationStageMetrics> Run(int* argc, char* argv[]);
protected:
// Returns a list of commandline flags that this task defines.
virtual std::vector<Flag> GetFlags() = 0;
virtual absl::optional<EvaluationStageMetrics> RunImpl() = 0;
DelegateProviders delegate_providers_;
};
// Just a declaration. In order to avoid the boilerplate main-function code,
// every evaluation task should define this function.
std::unique_ptr<TaskExecutor> CreateTaskExecutor(int* argc, char* argv[]);
std::unique_ptr<TaskExecutor> CreateTaskExecutor();
} // namespace evaluation
} // namespace tflite
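For illustration only (not in this commit): a new task written against the revised interface just declares its flags and implements RunImpl; parsing, usage logging, delegate-provider initialization, and the unconsumed-flag warning all happen in TaskExecutor::Run. A hypothetical DummyTask with a made-up flag name:

    #include <memory>
    #include <vector>

    #include "absl/types/optional.h"
    #include "tensorflow/lite/tools/command_line_flags.h"
    #include "tensorflow/lite/tools/evaluation/proto/evaluation_config.pb.h"
    #include "tensorflow/lite/tools/evaluation/tasks/task_executor.h"

    namespace tflite {
    namespace evaluation {

    class DummyTask : public TaskExecutor {
     public:
      DummyTask() : num_runs_(1) {}
      ~DummyTask() override {}

     protected:
      std::vector<Flag> GetFlags() final {
        // Only declare the flags; TaskExecutor::Run parses them (and logs the
        // usage string on failure) before RunImpl is ever invoked.
        return {tflite::Flag::CreateFlag("num_runs", &num_runs_,
                                         "Number of runs.")};
      }

      absl::optional<EvaluationStageMetrics> RunImpl() final {
        // delegate_providers_ is already initialized from the cmdline here.
        return EvaluationStageMetrics();
      }

     private:
      int num_runs_;
    };

    std::unique_ptr<TaskExecutor> CreateTaskExecutor() {
      return std::unique_ptr<TaskExecutor>(new DummyTask());
    }

    }  // namespace evaluation
    }  // namespace tflite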

@@ -18,12 +18,12 @@ limitations under the License.
// This could serve as the main function for all eval tools.
int main(int argc, char* argv[]) {
auto task_executor = tflite::evaluation::CreateTaskExecutor(&argc, argv);
auto task_executor = tflite::evaluation::CreateTaskExecutor();
if (task_executor == nullptr) {
TFLITE_LOG(ERROR) << "Could not create the task evaluation!";
return EXIT_FAILURE;
}
const auto metrics = task_executor->Run();
const auto metrics = task_executor->Run(&argc, argv);
if (!metrics.has_value()) {
TFLITE_LOG(ERROR) << "Could not run the task evaluation!";
return EXIT_FAILURE;