From a4f7dd5436885a8ecdc6ac34bc4689e7e04ed2af Mon Sep 17 00:00:00 2001
From: Chao Mei <chaomei@google.com>
Date: Tue, 23 Jun 2020 20:04:29 -0700
Subject: [PATCH] Support to output unconsumed flags and exit the execution if
 cmdline flags fail to be parsed for tflite evaluation tools.

PiperOrigin-RevId: 317989024
Change-Id: I52cc2249246b7d19c9c8a257ac1478d48f7de8fa
---
 .../evaluation_delegate_provider.cc           |  8 +++-
 tensorflow/lite/tools/evaluation/tasks/BUILD  |  4 ++
 .../tasks/coco_object_detection/BUILD         |  1 -
 .../tasks/coco_object_detection/run_eval.cc   | 21 ++++-----
 .../tasks/imagenet_image_classification/BUILD |  1 -
 .../imagenet_image_classification/run_eval.cc | 21 ++++-----
 .../evaluation/tasks/inference_diff/BUILD     |  1 -
 .../tasks/inference_diff/run_eval.cc          | 22 ++++-----
 .../tools/evaluation/tasks/task_executor.cc   | 47 +++++++++++++++++++
 .../tools/evaluation/tasks/task_executor.h    | 15 +++++-
 .../evaluation/tasks/task_executor_main.cc    |  4 +-
 11 files changed, 104 insertions(+), 41 deletions(-)
 create mode 100644 tensorflow/lite/tools/evaluation/tasks/task_executor.cc

diff --git a/tensorflow/lite/tools/evaluation/evaluation_delegate_provider.cc b/tensorflow/lite/tools/evaluation/evaluation_delegate_provider.cc
index 42f2666ba9b..fc40440b105 100644
--- a/tensorflow/lite/tools/evaluation/evaluation_delegate_provider.cc
+++ b/tensorflow/lite/tools/evaluation/evaluation_delegate_provider.cc
@@ -97,7 +97,13 @@ bool DelegateProviders::InitFromCmdlineArgs(int* argc, const char** argv) {
     auto one_flags = one->CreateFlags(&params_);
     flags.insert(flags.end(), one_flags.begin(), one_flags.end());
   }
-  return Flags::Parse(argc, argv, flags);
+
+  const bool parse_result = Flags::Parse(argc, argv, flags);
+  if (!parse_result) {
+    std::string usage = Flags::Usage(argv[0], flags);
+    TFLITE_LOG(ERROR) << usage;
+  }
+  return parse_result;
 }
 
 TfLiteDelegatePtr DelegateProviders::CreateDelegate(
diff --git a/tensorflow/lite/tools/evaluation/tasks/BUILD b/tensorflow/lite/tools/evaluation/tasks/BUILD
index d8daf170331..5272542f045 100644
--- a/tensorflow/lite/tools/evaluation/tasks/BUILD
+++ b/tensorflow/lite/tools/evaluation/tasks/BUILD
@@ -10,10 +10,14 @@ package(
 
 cc_library(
     name = "task_executor",
+    srcs = ["task_executor.cc"],
     hdrs = ["task_executor.h"],
     copts = tflite_copts(),
     linkopts = task_linkopts(),
     deps = [
+        "//tensorflow/lite/tools:command_line_flags",
+        "//tensorflow/lite/tools:logging",
+        "//tensorflow/lite/tools/evaluation:evaluation_delegate_provider",
         "//tensorflow/lite/tools/evaluation/proto:evaluation_config_cc_proto",
         "@com_google_absl//absl/types:optional",
     ],
diff --git a/tensorflow/lite/tools/evaluation/tasks/coco_object_detection/BUILD b/tensorflow/lite/tools/evaluation/tasks/coco_object_detection/BUILD
index b8f77d72acb..dc5f8237f6a 100644
--- a/tensorflow/lite/tools/evaluation/tasks/coco_object_detection/BUILD
+++ b/tensorflow/lite/tools/evaluation/tasks/coco_object_detection/BUILD
@@ -26,7 +26,6 @@ cc_library(
         "//tensorflow/lite/c:common",
         "//tensorflow/lite/tools:command_line_flags",
         "//tensorflow/lite/tools:logging",
-        "//tensorflow/lite/tools/evaluation:evaluation_delegate_provider",
         "//tensorflow/lite/tools/evaluation:evaluation_stage",
         "//tensorflow/lite/tools/evaluation:utils",
         "//tensorflow/lite/tools/evaluation/proto:evaluation_config_cc_proto",
diff --git a/tensorflow/lite/tools/evaluation/tasks/coco_object_detection/run_eval.cc b/tensorflow/lite/tools/evaluation/tasks/coco_object_detection/run_eval.cc
index 765e8fc6465..73491457f38 100644
--- a/tensorflow/lite/tools/evaluation/tasks/coco_object_detection/run_eval.cc
+++ b/tensorflow/lite/tools/evaluation/tasks/coco_object_detection/run_eval.cc
@@ -21,7 +21,6 @@ limitations under the License.
 #include "absl/types/optional.h"
 #include "tensorflow/lite/c/common.h"
 #include "tensorflow/lite/tools/command_line_flags.h"
-#include "tensorflow/lite/tools/evaluation/evaluation_delegate_provider.h"
 #include "tensorflow/lite/tools/evaluation/proto/evaluation_config.pb.h"
 #include "tensorflow/lite/tools/evaluation/proto/evaluation_stages.pb.h"
 #include "tensorflow/lite/tools/evaluation/stages/object_detection_stage.h"
@@ -49,11 +48,14 @@ std::string GetNameFromPath(const std::string& str) {
 
 class CocoObjectDetection : public TaskExecutor {
  public:
-  CocoObjectDetection(int* argc, char* argv[]);
+  CocoObjectDetection() : debug_mode_(false), num_interpreter_threads_(1) {}
   ~CocoObjectDetection() override {}
 
+ protected:
+  std::vector<Flag> GetFlags() final;
+
   // If the run is successful, the latest metrics will be returned.
-  absl::optional<EvaluationStageMetrics> Run() final;
+  absl::optional<EvaluationStageMetrics> RunImpl() final;
 
  private:
   void OutputResult(const EvaluationStageMetrics& latest_metrics) const;
@@ -68,8 +70,7 @@ class CocoObjectDetection : public TaskExecutor {
   DelegateProviders delegate_providers_;
 };
 
-CocoObjectDetection::CocoObjectDetection(int* argc, char* argv[])
-    : debug_mode_(false), num_interpreter_threads_(1) {
+std::vector<Flag> CocoObjectDetection::GetFlags() {
   std::vector<tflite::Flag> flag_list = {
       tflite::Flag::CreateFlag(kModelFileFlag, &model_file_path_,
                                "Path to test tflite model file."),
@@ -105,12 +106,10 @@ CocoObjectDetection::CocoObjectDetection(int* argc, char* argv[])
           "Delegate to use for inference, if available. "
           "Must be one of {'nnapi', 'gpu', 'xnnpack', 'hexagon'}"),
   };
-  tflite::Flags::Parse(argc, const_cast<const char**>(argv), flag_list);
-  DelegateProviders delegate_providers;
-  delegate_providers.InitFromCmdlineArgs(argc, const_cast<const char**>(argv));
+  return flag_list;
 }
 
-absl::optional<EvaluationStageMetrics> CocoObjectDetection::Run() {
+absl::optional<EvaluationStageMetrics> CocoObjectDetection::RunImpl() {
   // Process images in filename-sorted order.
   std::vector<std::string> image_paths;
   if (GetSortedFileNames(StripTrailingSlashes(ground_truth_images_path_),
@@ -224,8 +223,8 @@ void CocoObjectDetection::OutputResult(
                    << precision_metrics.overall_mean_average_precision();
 }
 
-std::unique_ptr<TaskExecutor> CreateTaskExecutor(int* argc, char* argv[]) {
-  return std::unique_ptr<TaskExecutor>(new CocoObjectDetection(argc, argv));
+std::unique_ptr<TaskExecutor> CreateTaskExecutor() {
+  return std::unique_ptr<TaskExecutor>(new CocoObjectDetection());
 }
 
 }  // namespace evaluation
diff --git a/tensorflow/lite/tools/evaluation/tasks/imagenet_image_classification/BUILD b/tensorflow/lite/tools/evaluation/tasks/imagenet_image_classification/BUILD
index de2a7f96311..941bbc0ff69 100644
--- a/tensorflow/lite/tools/evaluation/tasks/imagenet_image_classification/BUILD
+++ b/tensorflow/lite/tools/evaluation/tasks/imagenet_image_classification/BUILD
@@ -17,7 +17,6 @@ cc_library(
         "//tensorflow/lite/c:common",
         "//tensorflow/lite/tools:command_line_flags",
         "//tensorflow/lite/tools:logging",
-        "//tensorflow/lite/tools/evaluation:evaluation_delegate_provider",
         "//tensorflow/lite/tools/evaluation:evaluation_stage",
         "//tensorflow/lite/tools/evaluation:utils",
         "//tensorflow/lite/tools/evaluation/proto:evaluation_config_cc_proto",
diff --git a/tensorflow/lite/tools/evaluation/tasks/imagenet_image_classification/run_eval.cc b/tensorflow/lite/tools/evaluation/tasks/imagenet_image_classification/run_eval.cc
index 13eeb313ad4..fdc97d44abc 100644
--- a/tensorflow/lite/tools/evaluation/tasks/imagenet_image_classification/run_eval.cc
+++ b/tensorflow/lite/tools/evaluation/tasks/imagenet_image_classification/run_eval.cc
@@ -20,7 +20,6 @@ limitations under the License.
 #include "absl/types/optional.h"
 #include "tensorflow/lite/c/common.h"
 #include "tensorflow/lite/tools/command_line_flags.h"
-#include "tensorflow/lite/tools/evaluation/evaluation_delegate_provider.h"
 #include "tensorflow/lite/tools/evaluation/proto/evaluation_config.pb.h"
 #include "tensorflow/lite/tools/evaluation/proto/evaluation_stages.pb.h"
 #include "tensorflow/lite/tools/evaluation/stages/image_classification_stage.h"
@@ -50,11 +49,14 @@ std::vector<T> GetFirstN(const std::vector<T>& v, int n) {
 
 class ImagenetClassification : public TaskExecutor {
  public:
-  ImagenetClassification(int* argc, char* argv[]);
+  ImagenetClassification() : num_images_(0), num_interpreter_threads_(1) {}
   ~ImagenetClassification() override {}
 
+ protected:
+  std::vector<Flag> GetFlags() final;
+
   // If the run is successful, the latest metrics will be returned.
-  absl::optional<EvaluationStageMetrics> Run() final;
+  absl::optional<EvaluationStageMetrics> RunImpl() final;
 
  private:
   void OutputResult(const EvaluationStageMetrics& latest_metrics) const;
@@ -67,11 +69,9 @@ class ImagenetClassification : public TaskExecutor {
   std::string delegate_;
   int num_images_;
   int num_interpreter_threads_;
-  DelegateProviders delegate_providers_;
 };
 
-ImagenetClassification::ImagenetClassification(int* argc, char* argv[])
-    : num_images_(0), num_interpreter_threads_(1) {
+std::vector<Flag> ImagenetClassification::GetFlags() {
   std::vector<tflite::Flag> flag_list = {
       tflite::Flag::CreateFlag(kModelFileFlag, &model_file_path_,
                                "Path to test tflite model file."),
@@ -107,11 +107,10 @@ ImagenetClassification::ImagenetClassification(int* argc, char* argv[])
           "Delegate to use for inference, if available. "
           "Must be one of {'nnapi', 'gpu', 'hexagon', 'xnnpack'}"),
   };
-  tflite::Flags::Parse(argc, const_cast<const char**>(argv), flag_list);
-  delegate_providers_.InitFromCmdlineArgs(argc, const_cast<const char**>(argv));
+  return flag_list;
 }
 
-absl::optional<EvaluationStageMetrics> ImagenetClassification::Run() {
+absl::optional<EvaluationStageMetrics> ImagenetClassification::RunImpl() {
   // Process images in filename-sorted order.
   std::vector<std::string> image_files, ground_truth_image_labels;
   if (GetSortedFileNames(StripTrailingSlashes(ground_truth_images_path_),
@@ -203,8 +202,8 @@ void ImagenetClassification::OutputResult(
   }
 }
 
-std::unique_ptr<TaskExecutor> CreateTaskExecutor(int* argc, char* argv[]) {
-  return std::unique_ptr<TaskExecutor>(new ImagenetClassification(argc, argv));
+std::unique_ptr<TaskExecutor> CreateTaskExecutor() {
+  return std::unique_ptr<TaskExecutor>(new ImagenetClassification());
 }
 
 }  // namespace evaluation
diff --git a/tensorflow/lite/tools/evaluation/tasks/inference_diff/BUILD b/tensorflow/lite/tools/evaluation/tasks/inference_diff/BUILD
index a53872b50cb..36606722caf 100644
--- a/tensorflow/lite/tools/evaluation/tasks/inference_diff/BUILD
+++ b/tensorflow/lite/tools/evaluation/tasks/inference_diff/BUILD
@@ -17,7 +17,6 @@ cc_library(
         "//tensorflow/lite/c:common",
         "//tensorflow/lite/tools:command_line_flags",
         "//tensorflow/lite/tools:logging",
-        "//tensorflow/lite/tools/evaluation:evaluation_delegate_provider",
         "//tensorflow/lite/tools/evaluation:evaluation_stage",
         "//tensorflow/lite/tools/evaluation/proto:evaluation_config_cc_proto",
         "//tensorflow/lite/tools/evaluation/proto:evaluation_stages_cc_proto",
diff --git a/tensorflow/lite/tools/evaluation/tasks/inference_diff/run_eval.cc b/tensorflow/lite/tools/evaluation/tasks/inference_diff/run_eval.cc
index 814ebe3b3bf..9a3fea0b8a3 100644
--- a/tensorflow/lite/tools/evaluation/tasks/inference_diff/run_eval.cc
+++ b/tensorflow/lite/tools/evaluation/tasks/inference_diff/run_eval.cc
@@ -19,7 +19,6 @@ limitations under the License.
 #include "absl/types/optional.h"
 #include "tensorflow/lite/c/common.h"
 #include "tensorflow/lite/tools/command_line_flags.h"
-#include "tensorflow/lite/tools/evaluation/evaluation_delegate_provider.h"
 #include "tensorflow/lite/tools/evaluation/proto/evaluation_config.pb.h"
 #include "tensorflow/lite/tools/evaluation/proto/evaluation_stages.pb.h"
 #include "tensorflow/lite/tools/evaluation/stages/inference_profiler_stage.h"
@@ -37,11 +36,14 @@ constexpr char kDelegateFlag[] = "delegate";
 
 class InferenceDiff : public TaskExecutor {
  public:
-  InferenceDiff(int* argc, char* argv[]);
+  InferenceDiff() : num_runs_(50), num_interpreter_threads_(1) {}
   ~InferenceDiff() override {}
 
+ protected:
+  std::vector<Flag> GetFlags() final;
+
   // If the run is successful, the latest metrics will be returned.
-  absl::optional<EvaluationStageMetrics> Run() final;
+  absl::optional<EvaluationStageMetrics> RunImpl() final;
 
  private:
   void OutputResult(const EvaluationStageMetrics& latest_metrics) const;
@@ -50,11 +52,9 @@ class InferenceDiff : public TaskExecutor {
   std::string delegate_;
   int num_runs_;
   int num_interpreter_threads_;
-  DelegateProviders delegate_providers_;
 };
 
-InferenceDiff::InferenceDiff(int* argc, char* argv[])
-    : num_runs_(50), num_interpreter_threads_(1) {
+std::vector<Flag> InferenceDiff::GetFlags() {
   // Command Line Flags.
   std::vector<tflite::Flag> flag_list = {
       tflite::Flag::CreateFlag(kModelFileFlag, &model_file_path_,
@@ -72,11 +72,11 @@ InferenceDiff::InferenceDiff(int* argc, char* argv[])
           "Delegate to use for test inference, if available. "
           "Must be one of {'nnapi', 'gpu', 'hexagon', 'xnnpack'}"),
   };
-  tflite::Flags::Parse(argc, const_cast<const char**>(argv), flag_list);
-  delegate_providers_.InitFromCmdlineArgs(argc, const_cast<const char**>(argv));
+
+  return flag_list;
 }
 
-absl::optional<EvaluationStageMetrics> InferenceDiff::Run() {
+absl::optional<EvaluationStageMetrics> InferenceDiff::RunImpl() {
   // Initialize evaluation stage.
   EvaluationStageConfig eval_config;
   eval_config.set_name("inference_profiling");
@@ -137,8 +137,8 @@ void InferenceDiff::OutputResult(
   }
 }
 
-std::unique_ptr<TaskExecutor> CreateTaskExecutor(int* argc, char* argv[]) {
-  return std::unique_ptr<TaskExecutor>(new InferenceDiff(argc, argv));
+std::unique_ptr<TaskExecutor> CreateTaskExecutor() {
+  return std::unique_ptr<TaskExecutor>(new InferenceDiff());
 }
 
 }  // namespace evaluation
diff --git a/tensorflow/lite/tools/evaluation/tasks/task_executor.cc b/tensorflow/lite/tools/evaluation/tasks/task_executor.cc
new file mode 100644
index 00000000000..e62793dc6ff
--- /dev/null
+++ b/tensorflow/lite/tools/evaluation/tasks/task_executor.cc
@@ -0,0 +1,47 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/lite/tools/evaluation/tasks/task_executor.h"
+
+#include "absl/types/optional.h"
+#include "tensorflow/lite/tools/logging.h"
+
+namespace tflite {
+namespace evaluation {
+absl::optional<EvaluationStageMetrics> TaskExecutor::Run(int* argc,
+                                                         char* argv[]) {
+  auto flag_list = GetFlags();
+  bool parse_result =
+      tflite::Flags::Parse(argc, const_cast<const char**>(argv), flag_list);
+  if (!parse_result) {
+    std::string usage = Flags::Usage(argv[0], flag_list);
+    TFLITE_LOG(ERROR) << usage;
+    return absl::nullopt;
+  }
+  parse_result = delegate_providers_.InitFromCmdlineArgs(
+      argc, const_cast<const char**>(argv));
+  if (!parse_result) {
+    return absl::nullopt;
+  }
+
+  std::string unconsumed_args =
+      Flags::ArgsToString(*argc, const_cast<const char**>(argv));
+  if (!unconsumed_args.empty()) {
+    TFLITE_LOG(WARN) << "Unconsumed cmdline flags: " << unconsumed_args;
+  }
+
+  return RunImpl();
+}
+}  // namespace evaluation
+}  // namespace tflite
diff --git a/tensorflow/lite/tools/evaluation/tasks/task_executor.h b/tensorflow/lite/tools/evaluation/tasks/task_executor.h
index b50e7d6d03f..caa84283098 100644
--- a/tensorflow/lite/tools/evaluation/tasks/task_executor.h
+++ b/tensorflow/lite/tools/evaluation/tasks/task_executor.h
@@ -16,6 +16,8 @@ limitations under the License.
 #define TENSORFLOW_LITE_TOOLS_EVALUATION_TASKS_TASK_EXECUTOR_H_
 
 #include "absl/types/optional.h"
+#include "tensorflow/lite/tools/command_line_flags.h"
+#include "tensorflow/lite/tools/evaluation/evaluation_delegate_provider.h"
 #include "tensorflow/lite/tools/evaluation/proto/evaluation_config.pb.h"
 
 namespace tflite {
@@ -25,13 +27,22 @@ namespace evaluation {
 class TaskExecutor {
  public:
   virtual ~TaskExecutor() {}
+
   // If the run is successful, the latest metrics will be returned.
-  virtual absl::optional<EvaluationStageMetrics> Run() = 0;
+  absl::optional<EvaluationStageMetrics> Run(int* argc, char* argv[]);
+
+ protected:
+  // Returns a list of commandline flags that this task defines.
+  virtual std::vector<Flag> GetFlags() = 0;
+
+  virtual absl::optional<EvaluationStageMetrics> RunImpl() = 0;
+
+  DelegateProviders delegate_providers_;
 };
 
 // Just a declaration. In order to avoid the boilerpolate main-function code,
 // every evaluation task should define this function.
-std::unique_ptr<TaskExecutor> CreateTaskExecutor(int* argc, char* argv[]);
+std::unique_ptr<TaskExecutor> CreateTaskExecutor();
 }  // namespace evaluation
 }  // namespace tflite
 
diff --git a/tensorflow/lite/tools/evaluation/tasks/task_executor_main.cc b/tensorflow/lite/tools/evaluation/tasks/task_executor_main.cc
index 6ef1a6724b7..97f8e263659 100644
--- a/tensorflow/lite/tools/evaluation/tasks/task_executor_main.cc
+++ b/tensorflow/lite/tools/evaluation/tasks/task_executor_main.cc
@@ -18,12 +18,12 @@ limitations under the License.
 
 // This could serve as the main function for all eval tools.
 int main(int argc, char* argv[]) {
-  auto task_executor = tflite::evaluation::CreateTaskExecutor(&argc, argv);
+  auto task_executor = tflite::evaluation::CreateTaskExecutor();
   if (task_executor == nullptr) {
     TFLITE_LOG(ERROR) << "Could not create the task evaluation!";
     return EXIT_FAILURE;
   }
-  const auto metrics = task_executor->Run();
+  const auto metrics = task_executor->Run(&argc, argv);
   if (!metrics.has_value()) {
     TFLITE_LOG(ERROR) << "Could not run the task evaluation!";
     return EXIT_FAILURE;