Support to output unconsumed flags and exit the execution if cmdline flags fail to be parsed for tflite evaluation tools.
PiperOrigin-RevId: 317989024 Change-Id: I52cc2249246b7d19c9c8a257ac1478d48f7de8fa
This commit is contained in:
parent
ef20289d99
commit
a4f7dd5436
@ -97,7 +97,13 @@ bool DelegateProviders::InitFromCmdlineArgs(int* argc, const char** argv) {
|
||||
auto one_flags = one->CreateFlags(¶ms_);
|
||||
flags.insert(flags.end(), one_flags.begin(), one_flags.end());
|
||||
}
|
||||
return Flags::Parse(argc, argv, flags);
|
||||
|
||||
const bool parse_result = Flags::Parse(argc, argv, flags);
|
||||
if (!parse_result) {
|
||||
std::string usage = Flags::Usage(argv[0], flags);
|
||||
TFLITE_LOG(ERROR) << usage;
|
||||
}
|
||||
return parse_result;
|
||||
}
|
||||
|
||||
TfLiteDelegatePtr DelegateProviders::CreateDelegate(
|
||||
|
@ -10,10 +10,14 @@ package(
|
||||
|
||||
cc_library(
|
||||
name = "task_executor",
|
||||
srcs = ["task_executor.cc"],
|
||||
hdrs = ["task_executor.h"],
|
||||
copts = tflite_copts(),
|
||||
linkopts = task_linkopts(),
|
||||
deps = [
|
||||
"//tensorflow/lite/tools:command_line_flags",
|
||||
"//tensorflow/lite/tools:logging",
|
||||
"//tensorflow/lite/tools/evaluation:evaluation_delegate_provider",
|
||||
"//tensorflow/lite/tools/evaluation/proto:evaluation_config_cc_proto",
|
||||
"@com_google_absl//absl/types:optional",
|
||||
],
|
||||
|
@ -26,7 +26,6 @@ cc_library(
|
||||
"//tensorflow/lite/c:common",
|
||||
"//tensorflow/lite/tools:command_line_flags",
|
||||
"//tensorflow/lite/tools:logging",
|
||||
"//tensorflow/lite/tools/evaluation:evaluation_delegate_provider",
|
||||
"//tensorflow/lite/tools/evaluation:evaluation_stage",
|
||||
"//tensorflow/lite/tools/evaluation:utils",
|
||||
"//tensorflow/lite/tools/evaluation/proto:evaluation_config_cc_proto",
|
||||
|
@ -21,7 +21,6 @@ limitations under the License.
|
||||
#include "absl/types/optional.h"
|
||||
#include "tensorflow/lite/c/common.h"
|
||||
#include "tensorflow/lite/tools/command_line_flags.h"
|
||||
#include "tensorflow/lite/tools/evaluation/evaluation_delegate_provider.h"
|
||||
#include "tensorflow/lite/tools/evaluation/proto/evaluation_config.pb.h"
|
||||
#include "tensorflow/lite/tools/evaluation/proto/evaluation_stages.pb.h"
|
||||
#include "tensorflow/lite/tools/evaluation/stages/object_detection_stage.h"
|
||||
@ -49,11 +48,14 @@ std::string GetNameFromPath(const std::string& str) {
|
||||
|
||||
class CocoObjectDetection : public TaskExecutor {
|
||||
public:
|
||||
CocoObjectDetection(int* argc, char* argv[]);
|
||||
CocoObjectDetection() : debug_mode_(false), num_interpreter_threads_(1) {}
|
||||
~CocoObjectDetection() override {}
|
||||
|
||||
protected:
|
||||
std::vector<Flag> GetFlags() final;
|
||||
|
||||
// If the run is successful, the latest metrics will be returned.
|
||||
absl::optional<EvaluationStageMetrics> Run() final;
|
||||
absl::optional<EvaluationStageMetrics> RunImpl() final;
|
||||
|
||||
private:
|
||||
void OutputResult(const EvaluationStageMetrics& latest_metrics) const;
|
||||
@ -68,8 +70,7 @@ class CocoObjectDetection : public TaskExecutor {
|
||||
DelegateProviders delegate_providers_;
|
||||
};
|
||||
|
||||
CocoObjectDetection::CocoObjectDetection(int* argc, char* argv[])
|
||||
: debug_mode_(false), num_interpreter_threads_(1) {
|
||||
std::vector<Flag> CocoObjectDetection::GetFlags() {
|
||||
std::vector<tflite::Flag> flag_list = {
|
||||
tflite::Flag::CreateFlag(kModelFileFlag, &model_file_path_,
|
||||
"Path to test tflite model file."),
|
||||
@ -105,12 +106,10 @@ CocoObjectDetection::CocoObjectDetection(int* argc, char* argv[])
|
||||
"Delegate to use for inference, if available. "
|
||||
"Must be one of {'nnapi', 'gpu', 'xnnpack', 'hexagon'}"),
|
||||
};
|
||||
tflite::Flags::Parse(argc, const_cast<const char**>(argv), flag_list);
|
||||
DelegateProviders delegate_providers;
|
||||
delegate_providers.InitFromCmdlineArgs(argc, const_cast<const char**>(argv));
|
||||
return flag_list;
|
||||
}
|
||||
|
||||
absl::optional<EvaluationStageMetrics> CocoObjectDetection::Run() {
|
||||
absl::optional<EvaluationStageMetrics> CocoObjectDetection::RunImpl() {
|
||||
// Process images in filename-sorted order.
|
||||
std::vector<std::string> image_paths;
|
||||
if (GetSortedFileNames(StripTrailingSlashes(ground_truth_images_path_),
|
||||
@ -224,8 +223,8 @@ void CocoObjectDetection::OutputResult(
|
||||
<< precision_metrics.overall_mean_average_precision();
|
||||
}
|
||||
|
||||
std::unique_ptr<TaskExecutor> CreateTaskExecutor(int* argc, char* argv[]) {
|
||||
return std::unique_ptr<TaskExecutor>(new CocoObjectDetection(argc, argv));
|
||||
std::unique_ptr<TaskExecutor> CreateTaskExecutor() {
|
||||
return std::unique_ptr<TaskExecutor>(new CocoObjectDetection());
|
||||
}
|
||||
|
||||
} // namespace evaluation
|
||||
|
@ -17,7 +17,6 @@ cc_library(
|
||||
"//tensorflow/lite/c:common",
|
||||
"//tensorflow/lite/tools:command_line_flags",
|
||||
"//tensorflow/lite/tools:logging",
|
||||
"//tensorflow/lite/tools/evaluation:evaluation_delegate_provider",
|
||||
"//tensorflow/lite/tools/evaluation:evaluation_stage",
|
||||
"//tensorflow/lite/tools/evaluation:utils",
|
||||
"//tensorflow/lite/tools/evaluation/proto:evaluation_config_cc_proto",
|
||||
|
@ -20,7 +20,6 @@ limitations under the License.
|
||||
#include "absl/types/optional.h"
|
||||
#include "tensorflow/lite/c/common.h"
|
||||
#include "tensorflow/lite/tools/command_line_flags.h"
|
||||
#include "tensorflow/lite/tools/evaluation/evaluation_delegate_provider.h"
|
||||
#include "tensorflow/lite/tools/evaluation/proto/evaluation_config.pb.h"
|
||||
#include "tensorflow/lite/tools/evaluation/proto/evaluation_stages.pb.h"
|
||||
#include "tensorflow/lite/tools/evaluation/stages/image_classification_stage.h"
|
||||
@ -50,11 +49,14 @@ std::vector<T> GetFirstN(const std::vector<T>& v, int n) {
|
||||
|
||||
class ImagenetClassification : public TaskExecutor {
|
||||
public:
|
||||
ImagenetClassification(int* argc, char* argv[]);
|
||||
ImagenetClassification() : num_images_(0), num_interpreter_threads_(1) {}
|
||||
~ImagenetClassification() override {}
|
||||
|
||||
protected:
|
||||
std::vector<Flag> GetFlags() final;
|
||||
|
||||
// If the run is successful, the latest metrics will be returned.
|
||||
absl::optional<EvaluationStageMetrics> Run() final;
|
||||
absl::optional<EvaluationStageMetrics> RunImpl() final;
|
||||
|
||||
private:
|
||||
void OutputResult(const EvaluationStageMetrics& latest_metrics) const;
|
||||
@ -67,11 +69,9 @@ class ImagenetClassification : public TaskExecutor {
|
||||
std::string delegate_;
|
||||
int num_images_;
|
||||
int num_interpreter_threads_;
|
||||
DelegateProviders delegate_providers_;
|
||||
};
|
||||
|
||||
ImagenetClassification::ImagenetClassification(int* argc, char* argv[])
|
||||
: num_images_(0), num_interpreter_threads_(1) {
|
||||
std::vector<Flag> ImagenetClassification::GetFlags() {
|
||||
std::vector<tflite::Flag> flag_list = {
|
||||
tflite::Flag::CreateFlag(kModelFileFlag, &model_file_path_,
|
||||
"Path to test tflite model file."),
|
||||
@ -107,11 +107,10 @@ ImagenetClassification::ImagenetClassification(int* argc, char* argv[])
|
||||
"Delegate to use for inference, if available. "
|
||||
"Must be one of {'nnapi', 'gpu', 'hexagon', 'xnnpack'}"),
|
||||
};
|
||||
tflite::Flags::Parse(argc, const_cast<const char**>(argv), flag_list);
|
||||
delegate_providers_.InitFromCmdlineArgs(argc, const_cast<const char**>(argv));
|
||||
return flag_list;
|
||||
}
|
||||
|
||||
absl::optional<EvaluationStageMetrics> ImagenetClassification::Run() {
|
||||
absl::optional<EvaluationStageMetrics> ImagenetClassification::RunImpl() {
|
||||
// Process images in filename-sorted order.
|
||||
std::vector<std::string> image_files, ground_truth_image_labels;
|
||||
if (GetSortedFileNames(StripTrailingSlashes(ground_truth_images_path_),
|
||||
@ -203,8 +202,8 @@ void ImagenetClassification::OutputResult(
|
||||
}
|
||||
}
|
||||
|
||||
std::unique_ptr<TaskExecutor> CreateTaskExecutor(int* argc, char* argv[]) {
|
||||
return std::unique_ptr<TaskExecutor>(new ImagenetClassification(argc, argv));
|
||||
std::unique_ptr<TaskExecutor> CreateTaskExecutor() {
|
||||
return std::unique_ptr<TaskExecutor>(new ImagenetClassification());
|
||||
}
|
||||
|
||||
} // namespace evaluation
|
||||
|
@ -17,7 +17,6 @@ cc_library(
|
||||
"//tensorflow/lite/c:common",
|
||||
"//tensorflow/lite/tools:command_line_flags",
|
||||
"//tensorflow/lite/tools:logging",
|
||||
"//tensorflow/lite/tools/evaluation:evaluation_delegate_provider",
|
||||
"//tensorflow/lite/tools/evaluation:evaluation_stage",
|
||||
"//tensorflow/lite/tools/evaluation/proto:evaluation_config_cc_proto",
|
||||
"//tensorflow/lite/tools/evaluation/proto:evaluation_stages_cc_proto",
|
||||
|
@ -19,7 +19,6 @@ limitations under the License.
|
||||
#include "absl/types/optional.h"
|
||||
#include "tensorflow/lite/c/common.h"
|
||||
#include "tensorflow/lite/tools/command_line_flags.h"
|
||||
#include "tensorflow/lite/tools/evaluation/evaluation_delegate_provider.h"
|
||||
#include "tensorflow/lite/tools/evaluation/proto/evaluation_config.pb.h"
|
||||
#include "tensorflow/lite/tools/evaluation/proto/evaluation_stages.pb.h"
|
||||
#include "tensorflow/lite/tools/evaluation/stages/inference_profiler_stage.h"
|
||||
@ -37,11 +36,14 @@ constexpr char kDelegateFlag[] = "delegate";
|
||||
|
||||
class InferenceDiff : public TaskExecutor {
|
||||
public:
|
||||
InferenceDiff(int* argc, char* argv[]);
|
||||
InferenceDiff() : num_runs_(50), num_interpreter_threads_(1) {}
|
||||
~InferenceDiff() override {}
|
||||
|
||||
protected:
|
||||
std::vector<Flag> GetFlags() final;
|
||||
|
||||
// If the run is successful, the latest metrics will be returned.
|
||||
absl::optional<EvaluationStageMetrics> Run() final;
|
||||
absl::optional<EvaluationStageMetrics> RunImpl() final;
|
||||
|
||||
private:
|
||||
void OutputResult(const EvaluationStageMetrics& latest_metrics) const;
|
||||
@ -50,11 +52,9 @@ class InferenceDiff : public TaskExecutor {
|
||||
std::string delegate_;
|
||||
int num_runs_;
|
||||
int num_interpreter_threads_;
|
||||
DelegateProviders delegate_providers_;
|
||||
};
|
||||
|
||||
InferenceDiff::InferenceDiff(int* argc, char* argv[])
|
||||
: num_runs_(50), num_interpreter_threads_(1) {
|
||||
std::vector<Flag> InferenceDiff::GetFlags() {
|
||||
// Command Line Flags.
|
||||
std::vector<tflite::Flag> flag_list = {
|
||||
tflite::Flag::CreateFlag(kModelFileFlag, &model_file_path_,
|
||||
@ -72,11 +72,11 @@ InferenceDiff::InferenceDiff(int* argc, char* argv[])
|
||||
"Delegate to use for test inference, if available. "
|
||||
"Must be one of {'nnapi', 'gpu', 'hexagon', 'xnnpack'}"),
|
||||
};
|
||||
tflite::Flags::Parse(argc, const_cast<const char**>(argv), flag_list);
|
||||
delegate_providers_.InitFromCmdlineArgs(argc, const_cast<const char**>(argv));
|
||||
|
||||
return flag_list;
|
||||
}
|
||||
|
||||
absl::optional<EvaluationStageMetrics> InferenceDiff::Run() {
|
||||
absl::optional<EvaluationStageMetrics> InferenceDiff::RunImpl() {
|
||||
// Initialize evaluation stage.
|
||||
EvaluationStageConfig eval_config;
|
||||
eval_config.set_name("inference_profiling");
|
||||
@ -137,8 +137,8 @@ void InferenceDiff::OutputResult(
|
||||
}
|
||||
}
|
||||
|
||||
std::unique_ptr<TaskExecutor> CreateTaskExecutor(int* argc, char* argv[]) {
|
||||
return std::unique_ptr<TaskExecutor>(new InferenceDiff(argc, argv));
|
||||
std::unique_ptr<TaskExecutor> CreateTaskExecutor() {
|
||||
return std::unique_ptr<TaskExecutor>(new InferenceDiff());
|
||||
}
|
||||
|
||||
} // namespace evaluation
|
||||
|
47
tensorflow/lite/tools/evaluation/tasks/task_executor.cc
Normal file
47
tensorflow/lite/tools/evaluation/tasks/task_executor.cc
Normal file
@ -0,0 +1,47 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/lite/tools/evaluation/tasks/task_executor.h"
|
||||
|
||||
#include "absl/types/optional.h"
|
||||
#include "tensorflow/lite/tools/logging.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace evaluation {
|
||||
absl::optional<EvaluationStageMetrics> TaskExecutor::Run(int* argc,
|
||||
char* argv[]) {
|
||||
auto flag_list = GetFlags();
|
||||
bool parse_result =
|
||||
tflite::Flags::Parse(argc, const_cast<const char**>(argv), flag_list);
|
||||
if (!parse_result) {
|
||||
std::string usage = Flags::Usage(argv[0], flag_list);
|
||||
TFLITE_LOG(ERROR) << usage;
|
||||
return absl::nullopt;
|
||||
}
|
||||
parse_result = delegate_providers_.InitFromCmdlineArgs(
|
||||
argc, const_cast<const char**>(argv));
|
||||
if (!parse_result) {
|
||||
return absl::nullopt;
|
||||
}
|
||||
|
||||
std::string unconsumed_args =
|
||||
Flags::ArgsToString(*argc, const_cast<const char**>(argv));
|
||||
if (!unconsumed_args.empty()) {
|
||||
TFLITE_LOG(WARN) << "Unconsumed cmdline flags: " << unconsumed_args;
|
||||
}
|
||||
|
||||
return RunImpl();
|
||||
}
|
||||
} // namespace evaluation
|
||||
} // namespace tflite
|
@ -16,6 +16,8 @@ limitations under the License.
|
||||
#define TENSORFLOW_LITE_TOOLS_EVALUATION_TASKS_TASK_EXECUTOR_H_
|
||||
|
||||
#include "absl/types/optional.h"
|
||||
#include "tensorflow/lite/tools/command_line_flags.h"
|
||||
#include "tensorflow/lite/tools/evaluation/evaluation_delegate_provider.h"
|
||||
#include "tensorflow/lite/tools/evaluation/proto/evaluation_config.pb.h"
|
||||
|
||||
namespace tflite {
|
||||
@ -25,13 +27,22 @@ namespace evaluation {
|
||||
class TaskExecutor {
|
||||
public:
|
||||
virtual ~TaskExecutor() {}
|
||||
|
||||
// If the run is successful, the latest metrics will be returned.
|
||||
virtual absl::optional<EvaluationStageMetrics> Run() = 0;
|
||||
absl::optional<EvaluationStageMetrics> Run(int* argc, char* argv[]);
|
||||
|
||||
protected:
|
||||
// Returns a list of commandline flags that this task defines.
|
||||
virtual std::vector<Flag> GetFlags() = 0;
|
||||
|
||||
virtual absl::optional<EvaluationStageMetrics> RunImpl() = 0;
|
||||
|
||||
DelegateProviders delegate_providers_;
|
||||
};
|
||||
|
||||
// Just a declaration. In order to avoid the boilerpolate main-function code,
|
||||
// every evaluation task should define this function.
|
||||
std::unique_ptr<TaskExecutor> CreateTaskExecutor(int* argc, char* argv[]);
|
||||
std::unique_ptr<TaskExecutor> CreateTaskExecutor();
|
||||
} // namespace evaluation
|
||||
} // namespace tflite
|
||||
|
||||
|
@ -18,12 +18,12 @@ limitations under the License.
|
||||
|
||||
// This could serve as the main function for all eval tools.
|
||||
int main(int argc, char* argv[]) {
|
||||
auto task_executor = tflite::evaluation::CreateTaskExecutor(&argc, argv);
|
||||
auto task_executor = tflite::evaluation::CreateTaskExecutor();
|
||||
if (task_executor == nullptr) {
|
||||
TFLITE_LOG(ERROR) << "Could not create the task evaluation!";
|
||||
return EXIT_FAILURE;
|
||||
}
|
||||
const auto metrics = task_executor->Run();
|
||||
const auto metrics = task_executor->Run(&argc, argv);
|
||||
if (!metrics.has_value()) {
|
||||
TFLITE_LOG(ERROR) << "Could not run the task evaluation!";
|
||||
return EXIT_FAILURE;
|
||||
|
Loading…
Reference in New Issue
Block a user