diff --git a/tensorflow/lite/tools/evaluation/evaluation_delegate_provider.cc b/tensorflow/lite/tools/evaluation/evaluation_delegate_provider.cc index 42f2666ba9b..fc40440b105 100644 --- a/tensorflow/lite/tools/evaluation/evaluation_delegate_provider.cc +++ b/tensorflow/lite/tools/evaluation/evaluation_delegate_provider.cc @@ -97,7 +97,13 @@ bool DelegateProviders::InitFromCmdlineArgs(int* argc, const char** argv) { auto one_flags = one->CreateFlags(¶ms_); flags.insert(flags.end(), one_flags.begin(), one_flags.end()); } - return Flags::Parse(argc, argv, flags); + + const bool parse_result = Flags::Parse(argc, argv, flags); + if (!parse_result) { + std::string usage = Flags::Usage(argv[0], flags); + TFLITE_LOG(ERROR) << usage; + } + return parse_result; } TfLiteDelegatePtr DelegateProviders::CreateDelegate( diff --git a/tensorflow/lite/tools/evaluation/tasks/BUILD b/tensorflow/lite/tools/evaluation/tasks/BUILD index d8daf170331..5272542f045 100644 --- a/tensorflow/lite/tools/evaluation/tasks/BUILD +++ b/tensorflow/lite/tools/evaluation/tasks/BUILD @@ -10,10 +10,14 @@ package( cc_library( name = "task_executor", + srcs = ["task_executor.cc"], hdrs = ["task_executor.h"], copts = tflite_copts(), linkopts = task_linkopts(), deps = [ + "//tensorflow/lite/tools:command_line_flags", + "//tensorflow/lite/tools:logging", + "//tensorflow/lite/tools/evaluation:evaluation_delegate_provider", "//tensorflow/lite/tools/evaluation/proto:evaluation_config_cc_proto", "@com_google_absl//absl/types:optional", ], diff --git a/tensorflow/lite/tools/evaluation/tasks/coco_object_detection/BUILD b/tensorflow/lite/tools/evaluation/tasks/coco_object_detection/BUILD index b8f77d72acb..dc5f8237f6a 100644 --- a/tensorflow/lite/tools/evaluation/tasks/coco_object_detection/BUILD +++ b/tensorflow/lite/tools/evaluation/tasks/coco_object_detection/BUILD @@ -26,7 +26,6 @@ cc_library( "//tensorflow/lite/c:common", "//tensorflow/lite/tools:command_line_flags", "//tensorflow/lite/tools:logging", - "//tensorflow/lite/tools/evaluation:evaluation_delegate_provider", "//tensorflow/lite/tools/evaluation:evaluation_stage", "//tensorflow/lite/tools/evaluation:utils", "//tensorflow/lite/tools/evaluation/proto:evaluation_config_cc_proto", diff --git a/tensorflow/lite/tools/evaluation/tasks/coco_object_detection/run_eval.cc b/tensorflow/lite/tools/evaluation/tasks/coco_object_detection/run_eval.cc index 765e8fc6465..73491457f38 100644 --- a/tensorflow/lite/tools/evaluation/tasks/coco_object_detection/run_eval.cc +++ b/tensorflow/lite/tools/evaluation/tasks/coco_object_detection/run_eval.cc @@ -21,7 +21,6 @@ limitations under the License. #include "absl/types/optional.h" #include "tensorflow/lite/c/common.h" #include "tensorflow/lite/tools/command_line_flags.h" -#include "tensorflow/lite/tools/evaluation/evaluation_delegate_provider.h" #include "tensorflow/lite/tools/evaluation/proto/evaluation_config.pb.h" #include "tensorflow/lite/tools/evaluation/proto/evaluation_stages.pb.h" #include "tensorflow/lite/tools/evaluation/stages/object_detection_stage.h" @@ -49,11 +48,14 @@ std::string GetNameFromPath(const std::string& str) { class CocoObjectDetection : public TaskExecutor { public: - CocoObjectDetection(int* argc, char* argv[]); + CocoObjectDetection() : debug_mode_(false), num_interpreter_threads_(1) {} ~CocoObjectDetection() override {} + protected: + std::vector GetFlags() final; + // If the run is successful, the latest metrics will be returned. - absl::optional Run() final; + absl::optional RunImpl() final; private: void OutputResult(const EvaluationStageMetrics& latest_metrics) const; @@ -68,8 +70,7 @@ class CocoObjectDetection : public TaskExecutor { DelegateProviders delegate_providers_; }; -CocoObjectDetection::CocoObjectDetection(int* argc, char* argv[]) - : debug_mode_(false), num_interpreter_threads_(1) { +std::vector CocoObjectDetection::GetFlags() { std::vector flag_list = { tflite::Flag::CreateFlag(kModelFileFlag, &model_file_path_, "Path to test tflite model file."), @@ -105,12 +106,10 @@ CocoObjectDetection::CocoObjectDetection(int* argc, char* argv[]) "Delegate to use for inference, if available. " "Must be one of {'nnapi', 'gpu', 'xnnpack', 'hexagon'}"), }; - tflite::Flags::Parse(argc, const_cast(argv), flag_list); - DelegateProviders delegate_providers; - delegate_providers.InitFromCmdlineArgs(argc, const_cast(argv)); + return flag_list; } -absl::optional CocoObjectDetection::Run() { +absl::optional CocoObjectDetection::RunImpl() { // Process images in filename-sorted order. std::vector image_paths; if (GetSortedFileNames(StripTrailingSlashes(ground_truth_images_path_), @@ -224,8 +223,8 @@ void CocoObjectDetection::OutputResult( << precision_metrics.overall_mean_average_precision(); } -std::unique_ptr CreateTaskExecutor(int* argc, char* argv[]) { - return std::unique_ptr(new CocoObjectDetection(argc, argv)); +std::unique_ptr CreateTaskExecutor() { + return std::unique_ptr(new CocoObjectDetection()); } } // namespace evaluation diff --git a/tensorflow/lite/tools/evaluation/tasks/imagenet_image_classification/BUILD b/tensorflow/lite/tools/evaluation/tasks/imagenet_image_classification/BUILD index de2a7f96311..941bbc0ff69 100644 --- a/tensorflow/lite/tools/evaluation/tasks/imagenet_image_classification/BUILD +++ b/tensorflow/lite/tools/evaluation/tasks/imagenet_image_classification/BUILD @@ -17,7 +17,6 @@ cc_library( "//tensorflow/lite/c:common", "//tensorflow/lite/tools:command_line_flags", "//tensorflow/lite/tools:logging", - "//tensorflow/lite/tools/evaluation:evaluation_delegate_provider", "//tensorflow/lite/tools/evaluation:evaluation_stage", "//tensorflow/lite/tools/evaluation:utils", "//tensorflow/lite/tools/evaluation/proto:evaluation_config_cc_proto", diff --git a/tensorflow/lite/tools/evaluation/tasks/imagenet_image_classification/run_eval.cc b/tensorflow/lite/tools/evaluation/tasks/imagenet_image_classification/run_eval.cc index 13eeb313ad4..fdc97d44abc 100644 --- a/tensorflow/lite/tools/evaluation/tasks/imagenet_image_classification/run_eval.cc +++ b/tensorflow/lite/tools/evaluation/tasks/imagenet_image_classification/run_eval.cc @@ -20,7 +20,6 @@ limitations under the License. #include "absl/types/optional.h" #include "tensorflow/lite/c/common.h" #include "tensorflow/lite/tools/command_line_flags.h" -#include "tensorflow/lite/tools/evaluation/evaluation_delegate_provider.h" #include "tensorflow/lite/tools/evaluation/proto/evaluation_config.pb.h" #include "tensorflow/lite/tools/evaluation/proto/evaluation_stages.pb.h" #include "tensorflow/lite/tools/evaluation/stages/image_classification_stage.h" @@ -50,11 +49,14 @@ std::vector GetFirstN(const std::vector& v, int n) { class ImagenetClassification : public TaskExecutor { public: - ImagenetClassification(int* argc, char* argv[]); + ImagenetClassification() : num_images_(0), num_interpreter_threads_(1) {} ~ImagenetClassification() override {} + protected: + std::vector GetFlags() final; + // If the run is successful, the latest metrics will be returned. - absl::optional Run() final; + absl::optional RunImpl() final; private: void OutputResult(const EvaluationStageMetrics& latest_metrics) const; @@ -67,11 +69,9 @@ class ImagenetClassification : public TaskExecutor { std::string delegate_; int num_images_; int num_interpreter_threads_; - DelegateProviders delegate_providers_; }; -ImagenetClassification::ImagenetClassification(int* argc, char* argv[]) - : num_images_(0), num_interpreter_threads_(1) { +std::vector ImagenetClassification::GetFlags() { std::vector flag_list = { tflite::Flag::CreateFlag(kModelFileFlag, &model_file_path_, "Path to test tflite model file."), @@ -107,11 +107,10 @@ ImagenetClassification::ImagenetClassification(int* argc, char* argv[]) "Delegate to use for inference, if available. " "Must be one of {'nnapi', 'gpu', 'hexagon', 'xnnpack'}"), }; - tflite::Flags::Parse(argc, const_cast(argv), flag_list); - delegate_providers_.InitFromCmdlineArgs(argc, const_cast(argv)); + return flag_list; } -absl::optional ImagenetClassification::Run() { +absl::optional ImagenetClassification::RunImpl() { // Process images in filename-sorted order. std::vector image_files, ground_truth_image_labels; if (GetSortedFileNames(StripTrailingSlashes(ground_truth_images_path_), @@ -203,8 +202,8 @@ void ImagenetClassification::OutputResult( } } -std::unique_ptr CreateTaskExecutor(int* argc, char* argv[]) { - return std::unique_ptr(new ImagenetClassification(argc, argv)); +std::unique_ptr CreateTaskExecutor() { + return std::unique_ptr(new ImagenetClassification()); } } // namespace evaluation diff --git a/tensorflow/lite/tools/evaluation/tasks/inference_diff/BUILD b/tensorflow/lite/tools/evaluation/tasks/inference_diff/BUILD index a53872b50cb..36606722caf 100644 --- a/tensorflow/lite/tools/evaluation/tasks/inference_diff/BUILD +++ b/tensorflow/lite/tools/evaluation/tasks/inference_diff/BUILD @@ -17,7 +17,6 @@ cc_library( "//tensorflow/lite/c:common", "//tensorflow/lite/tools:command_line_flags", "//tensorflow/lite/tools:logging", - "//tensorflow/lite/tools/evaluation:evaluation_delegate_provider", "//tensorflow/lite/tools/evaluation:evaluation_stage", "//tensorflow/lite/tools/evaluation/proto:evaluation_config_cc_proto", "//tensorflow/lite/tools/evaluation/proto:evaluation_stages_cc_proto", diff --git a/tensorflow/lite/tools/evaluation/tasks/inference_diff/run_eval.cc b/tensorflow/lite/tools/evaluation/tasks/inference_diff/run_eval.cc index 814ebe3b3bf..9a3fea0b8a3 100644 --- a/tensorflow/lite/tools/evaluation/tasks/inference_diff/run_eval.cc +++ b/tensorflow/lite/tools/evaluation/tasks/inference_diff/run_eval.cc @@ -19,7 +19,6 @@ limitations under the License. #include "absl/types/optional.h" #include "tensorflow/lite/c/common.h" #include "tensorflow/lite/tools/command_line_flags.h" -#include "tensorflow/lite/tools/evaluation/evaluation_delegate_provider.h" #include "tensorflow/lite/tools/evaluation/proto/evaluation_config.pb.h" #include "tensorflow/lite/tools/evaluation/proto/evaluation_stages.pb.h" #include "tensorflow/lite/tools/evaluation/stages/inference_profiler_stage.h" @@ -37,11 +36,14 @@ constexpr char kDelegateFlag[] = "delegate"; class InferenceDiff : public TaskExecutor { public: - InferenceDiff(int* argc, char* argv[]); + InferenceDiff() : num_runs_(50), num_interpreter_threads_(1) {} ~InferenceDiff() override {} + protected: + std::vector GetFlags() final; + // If the run is successful, the latest metrics will be returned. - absl::optional Run() final; + absl::optional RunImpl() final; private: void OutputResult(const EvaluationStageMetrics& latest_metrics) const; @@ -50,11 +52,9 @@ class InferenceDiff : public TaskExecutor { std::string delegate_; int num_runs_; int num_interpreter_threads_; - DelegateProviders delegate_providers_; }; -InferenceDiff::InferenceDiff(int* argc, char* argv[]) - : num_runs_(50), num_interpreter_threads_(1) { +std::vector InferenceDiff::GetFlags() { // Command Line Flags. std::vector flag_list = { tflite::Flag::CreateFlag(kModelFileFlag, &model_file_path_, @@ -72,11 +72,11 @@ InferenceDiff::InferenceDiff(int* argc, char* argv[]) "Delegate to use for test inference, if available. " "Must be one of {'nnapi', 'gpu', 'hexagon', 'xnnpack'}"), }; - tflite::Flags::Parse(argc, const_cast(argv), flag_list); - delegate_providers_.InitFromCmdlineArgs(argc, const_cast(argv)); + + return flag_list; } -absl::optional InferenceDiff::Run() { +absl::optional InferenceDiff::RunImpl() { // Initialize evaluation stage. EvaluationStageConfig eval_config; eval_config.set_name("inference_profiling"); @@ -137,8 +137,8 @@ void InferenceDiff::OutputResult( } } -std::unique_ptr CreateTaskExecutor(int* argc, char* argv[]) { - return std::unique_ptr(new InferenceDiff(argc, argv)); +std::unique_ptr CreateTaskExecutor() { + return std::unique_ptr(new InferenceDiff()); } } // namespace evaluation diff --git a/tensorflow/lite/tools/evaluation/tasks/task_executor.cc b/tensorflow/lite/tools/evaluation/tasks/task_executor.cc new file mode 100644 index 00000000000..e62793dc6ff --- /dev/null +++ b/tensorflow/lite/tools/evaluation/tasks/task_executor.cc @@ -0,0 +1,47 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/lite/tools/evaluation/tasks/task_executor.h" + +#include "absl/types/optional.h" +#include "tensorflow/lite/tools/logging.h" + +namespace tflite { +namespace evaluation { +absl::optional TaskExecutor::Run(int* argc, + char* argv[]) { + auto flag_list = GetFlags(); + bool parse_result = + tflite::Flags::Parse(argc, const_cast(argv), flag_list); + if (!parse_result) { + std::string usage = Flags::Usage(argv[0], flag_list); + TFLITE_LOG(ERROR) << usage; + return absl::nullopt; + } + parse_result = delegate_providers_.InitFromCmdlineArgs( + argc, const_cast(argv)); + if (!parse_result) { + return absl::nullopt; + } + + std::string unconsumed_args = + Flags::ArgsToString(*argc, const_cast(argv)); + if (!unconsumed_args.empty()) { + TFLITE_LOG(WARN) << "Unconsumed cmdline flags: " << unconsumed_args; + } + + return RunImpl(); +} +} // namespace evaluation +} // namespace tflite diff --git a/tensorflow/lite/tools/evaluation/tasks/task_executor.h b/tensorflow/lite/tools/evaluation/tasks/task_executor.h index b50e7d6d03f..caa84283098 100644 --- a/tensorflow/lite/tools/evaluation/tasks/task_executor.h +++ b/tensorflow/lite/tools/evaluation/tasks/task_executor.h @@ -16,6 +16,8 @@ limitations under the License. #define TENSORFLOW_LITE_TOOLS_EVALUATION_TASKS_TASK_EXECUTOR_H_ #include "absl/types/optional.h" +#include "tensorflow/lite/tools/command_line_flags.h" +#include "tensorflow/lite/tools/evaluation/evaluation_delegate_provider.h" #include "tensorflow/lite/tools/evaluation/proto/evaluation_config.pb.h" namespace tflite { @@ -25,13 +27,22 @@ namespace evaluation { class TaskExecutor { public: virtual ~TaskExecutor() {} + // If the run is successful, the latest metrics will be returned. - virtual absl::optional Run() = 0; + absl::optional Run(int* argc, char* argv[]); + + protected: + // Returns a list of commandline flags that this task defines. + virtual std::vector GetFlags() = 0; + + virtual absl::optional RunImpl() = 0; + + DelegateProviders delegate_providers_; }; // Just a declaration. In order to avoid the boilerpolate main-function code, // every evaluation task should define this function. -std::unique_ptr CreateTaskExecutor(int* argc, char* argv[]); +std::unique_ptr CreateTaskExecutor(); } // namespace evaluation } // namespace tflite diff --git a/tensorflow/lite/tools/evaluation/tasks/task_executor_main.cc b/tensorflow/lite/tools/evaluation/tasks/task_executor_main.cc index 6ef1a6724b7..97f8e263659 100644 --- a/tensorflow/lite/tools/evaluation/tasks/task_executor_main.cc +++ b/tensorflow/lite/tools/evaluation/tasks/task_executor_main.cc @@ -18,12 +18,12 @@ limitations under the License. // This could serve as the main function for all eval tools. int main(int argc, char* argv[]) { - auto task_executor = tflite::evaluation::CreateTaskExecutor(&argc, argv); + auto task_executor = tflite::evaluation::CreateTaskExecutor(); if (task_executor == nullptr) { TFLITE_LOG(ERROR) << "Could not create the task evaluation!"; return EXIT_FAILURE; } - const auto metrics = task_executor->Run(); + const auto metrics = task_executor->Run(&argc, argv); if (!metrics.has_value()) { TFLITE_LOG(ERROR) << "Could not run the task evaluation!"; return EXIT_FAILURE;