From a8c7cc20794dffce281a0cbeeafae730d853a0f7 Mon Sep 17 00:00:00 2001 From: Chao Mei Date: Sun, 19 Apr 2020 20:41:28 -0700 Subject: [PATCH] Refactor the inference_diff and imagenet-accuracy evaluation task to become a utility library, and create a common main stub function for these evaluation tasks. PiperOrigin-RevId: 307327740 Change-Id: Id1c081fa3ad1796c6ff60441f9189a985148fb09 --- tensorflow/lite/tools/evaluation/tasks/BUILD | 32 +++ .../lite/tools/evaluation/tasks/build_def.bzl | 14 ++ .../tasks/imagenet_image_classification/BUILD | 30 ++- .../imagenet_image_classification/run_eval.cc | 230 +++++++++--------- .../evaluation/tasks/inference_diff/BUILD | 28 ++- .../tasks/inference_diff/run_eval.cc | 122 +++++----- .../tools/evaluation/tasks/task_executor.h | 38 +++ .../evaluation/tasks/task_executor_main.cc | 32 +++ 8 files changed, 333 insertions(+), 193 deletions(-) create mode 100644 tensorflow/lite/tools/evaluation/tasks/BUILD create mode 100644 tensorflow/lite/tools/evaluation/tasks/build_def.bzl create mode 100644 tensorflow/lite/tools/evaluation/tasks/task_executor.h create mode 100644 tensorflow/lite/tools/evaluation/tasks/task_executor_main.cc diff --git a/tensorflow/lite/tools/evaluation/tasks/BUILD b/tensorflow/lite/tools/evaluation/tasks/BUILD new file mode 100644 index 00000000000..d8daf170331 --- /dev/null +++ b/tensorflow/lite/tools/evaluation/tasks/BUILD @@ -0,0 +1,32 @@ +load("//tensorflow/lite:build_def.bzl", "tflite_copts") +load("//tensorflow/lite/tools/evaluation/tasks:build_def.bzl", "task_linkopts") + +package( + default_visibility = [ + "//visibility:public", + ], + licenses = ["notice"], # Apache 2.0 +) + +cc_library( + name = "task_executor", + hdrs = ["task_executor.h"], + copts = tflite_copts(), + linkopts = task_linkopts(), + deps = [ + "//tensorflow/lite/tools/evaluation/proto:evaluation_config_cc_proto", + "@com_google_absl//absl/types:optional", + ], +) + +cc_library( + name = "task_executor_main", + srcs = ["task_executor_main.cc"], + copts = tflite_copts(), + linkopts = task_linkopts(), + deps = [ + ":task_executor", + "//tensorflow/lite/tools:logging", + "@com_google_absl//absl/types:optional", + ], +) diff --git a/tensorflow/lite/tools/evaluation/tasks/build_def.bzl b/tensorflow/lite/tools/evaluation/tasks/build_def.bzl new file mode 100644 index 00000000000..0d71b4436b2 --- /dev/null +++ b/tensorflow/lite/tools/evaluation/tasks/build_def.bzl @@ -0,0 +1,14 @@ +"""Common BUILD-related definitions across different tasks""" + +load("//tensorflow/lite:build_def.bzl", "tflite_linkopts") + +def task_linkopts(): + return tflite_linkopts() + select({ + "//tensorflow:android": [ + "-pie", # Android 5.0 and later supports only PIE + "-lm", # some builtin ops, e.g., tanh, need -lm + # Hexagon delegate libraries should be in /data/local/tmp + "-Wl,--rpath=/data/local/tmp/", + ], + "//conditions:default": [], + }) diff --git a/tensorflow/lite/tools/evaluation/tasks/imagenet_image_classification/BUILD b/tensorflow/lite/tools/evaluation/tasks/imagenet_image_classification/BUILD index 606b595ef29..04b12be83c2 100644 --- a/tensorflow/lite/tools/evaluation/tasks/imagenet_image_classification/BUILD +++ b/tensorflow/lite/tools/evaluation/tasks/imagenet_image_classification/BUILD @@ -1,4 +1,5 @@ -load("//tensorflow/lite:build_def.bzl", "tflite_copts", "tflite_linkopts") +load("//tensorflow/lite:build_def.bzl", "tflite_copts") +load("//tensorflow/lite/tools/evaluation/tasks:build_def.bzl", "task_linkopts") package( default_visibility = [ @@ -7,20 +8,11 @@ package( licenses = ["notice"], # Apache 2.0 ) -common_linkopts = tflite_linkopts() + select({ - "//tensorflow:android": [ - "-pie", # Android 5.0 and later supports only PIE - "-lm", # some builtin ops, e.g., tanh, need -lm - "-Wl,--rpath=/data/local/tmp/", # Hexagon delegate libraries should be in /data/local/tmp - ], - "//conditions:default": [], -}) - -cc_binary( - name = "run_eval", +cc_library( + name = "run_eval_lib", srcs = ["run_eval.cc"], copts = tflite_copts(), - linkopts = common_linkopts, + linkopts = task_linkopts(), deps = [ "//tensorflow/lite/c:common", "//tensorflow/lite/tools:command_line_flags", @@ -31,5 +23,17 @@ cc_binary( "//tensorflow/lite/tools/evaluation/proto:evaluation_config_cc_proto", "//tensorflow/lite/tools/evaluation/proto:evaluation_stages_cc_proto", "//tensorflow/lite/tools/evaluation/stages:image_classification_stage", + "//tensorflow/lite/tools/evaluation/tasks:task_executor", + "@com_google_absl//absl/types:optional", + ], +) + +cc_binary( + name = "run_eval", + copts = tflite_copts(), + linkopts = task_linkopts(), + deps = [ + ":run_eval_lib", + "//tensorflow/lite/tools/evaluation/tasks:task_executor_main", ], ) diff --git a/tensorflow/lite/tools/evaluation/tasks/imagenet_image_classification/run_eval.cc b/tensorflow/lite/tools/evaluation/tasks/imagenet_image_classification/run_eval.cc index c9db76454bc..13eeb313ad4 100644 --- a/tensorflow/lite/tools/evaluation/tasks/imagenet_image_classification/run_eval.cc +++ b/tensorflow/lite/tools/evaluation/tasks/imagenet_image_classification/run_eval.cc @@ -17,12 +17,14 @@ limitations under the License. #include #include +#include "absl/types/optional.h" #include "tensorflow/lite/c/common.h" #include "tensorflow/lite/tools/command_line_flags.h" #include "tensorflow/lite/tools/evaluation/evaluation_delegate_provider.h" #include "tensorflow/lite/tools/evaluation/proto/evaluation_config.pb.h" #include "tensorflow/lite/tools/evaluation/proto/evaluation_stages.pb.h" #include "tensorflow/lite/tools/evaluation/stages/image_classification_stage.h" +#include "tensorflow/lite/tools/evaluation/tasks/task_executor.h" #include "tensorflow/lite/tools/evaluation/utils.h" #include "tensorflow/lite/tools/logging.h" @@ -46,49 +48,143 @@ std::vector GetFirstN(const std::vector& v, int n) { return result; } -bool EvaluateModel(const std::string& model_file_path, - const std::vector& image_labels, - const std::vector& model_labels, - std::string delegate, std::string output_file_path, - int num_interpreter_threads, - const DelegateProviders& delegate_providers) { +class ImagenetClassification : public TaskExecutor { + public: + ImagenetClassification(int* argc, char* argv[]); + ~ImagenetClassification() override {} + + // If the run is successful, the latest metrics will be returned. + absl::optional Run() final; + + private: + void OutputResult(const EvaluationStageMetrics& latest_metrics) const; + std::string model_file_path_; + std::string ground_truth_images_path_; + std::string ground_truth_labels_path_; + std::string model_output_labels_path_; + std::string blacklist_file_path_; + std::string output_file_path_; + std::string delegate_; + int num_images_; + int num_interpreter_threads_; + DelegateProviders delegate_providers_; +}; + +ImagenetClassification::ImagenetClassification(int* argc, char* argv[]) + : num_images_(0), num_interpreter_threads_(1) { + std::vector flag_list = { + tflite::Flag::CreateFlag(kModelFileFlag, &model_file_path_, + "Path to test tflite model file."), + tflite::Flag::CreateFlag( + kModelOutputLabelsFlag, &model_output_labels_path_, + "Path to labels that correspond to output of model." + " E.g. in case of mobilenet, this is the path to label " + "file where each label is in the same order as the output" + " of the model."), + tflite::Flag::CreateFlag( + kGroundTruthImagesPathFlag, &ground_truth_images_path_, + "Path to ground truth images. These will be evaluated in " + "alphabetical order of filename"), + tflite::Flag::CreateFlag( + kGroundTruthLabelsFlag, &ground_truth_labels_path_, + "Path to ground truth labels, corresponding to alphabetical ordering " + "of ground truth images."), + tflite::Flag::CreateFlag( + kBlacklistFilePathFlag, &blacklist_file_path_, + "Path to blacklist file (optional) where each line is a single " + "integer that is " + "equal to index number of blacklisted image."), + tflite::Flag::CreateFlag(kOutputFilePathFlag, &output_file_path_, + "File to output metrics proto to."), + tflite::Flag::CreateFlag(kNumImagesFlag, &num_images_, + "Number of examples to evaluate, pass 0 for all " + "examples. Default: 0"), + tflite::Flag::CreateFlag( + kInterpreterThreadsFlag, &num_interpreter_threads_, + "Number of interpreter threads to use for inference."), + tflite::Flag::CreateFlag( + kDelegateFlag, &delegate_, + "Delegate to use for inference, if available. " + "Must be one of {'nnapi', 'gpu', 'hexagon', 'xnnpack'}"), + }; + tflite::Flags::Parse(argc, const_cast(argv), flag_list); + delegate_providers_.InitFromCmdlineArgs(argc, const_cast(argv)); +} + +absl::optional ImagenetClassification::Run() { + // Process images in filename-sorted order. + std::vector image_files, ground_truth_image_labels; + if (GetSortedFileNames(StripTrailingSlashes(ground_truth_images_path_), + &image_files) != kTfLiteOk) { + return absl::nullopt; + } + if (!ReadFileLines(ground_truth_labels_path_, &ground_truth_image_labels)) { + TFLITE_LOG(ERROR) << "Could not read ground truth labels file"; + return absl::nullopt; + } + if (image_files.size() != ground_truth_image_labels.size()) { + TFLITE_LOG(ERROR) << "Number of images and ground truth labels is not same"; + return absl::nullopt; + } + std::vector image_labels; + image_labels.reserve(image_files.size()); + for (int i = 0; i < image_files.size(); i++) { + image_labels.push_back({image_files[i], ground_truth_image_labels[i]}); + } + + // Filter out blacklisted/unwanted images. + if (FilterBlackListedImages(blacklist_file_path_, &image_labels) != + kTfLiteOk) { + return absl::nullopt; + } + if (num_images_ > 0) { + image_labels = GetFirstN(image_labels, num_images_); + } + + std::vector model_labels; + if (!ReadFileLines(model_output_labels_path_, &model_labels)) { + TFLITE_LOG(ERROR) << "Could not read model output labels file"; + return absl::nullopt; + } + EvaluationStageConfig eval_config; eval_config.set_name("image_classification"); auto* classification_params = eval_config.mutable_specification() ->mutable_image_classification_params(); auto* inference_params = classification_params->mutable_inference_params(); - inference_params->set_model_file_path(model_file_path); - inference_params->set_num_threads(num_interpreter_threads); - inference_params->set_delegate(ParseStringToDelegateType(delegate)); - if (!delegate.empty() && - inference_params->delegate() == TfliteInferenceParams::NONE) { - TFLITE_LOG(WARN) << "Unsupported TFLite delegate: " << delegate; - return false; - } + inference_params->set_model_file_path(model_file_path_); + inference_params->set_num_threads(num_interpreter_threads_); + inference_params->set_delegate(ParseStringToDelegateType(delegate_)); classification_params->mutable_topk_accuracy_eval_params()->set_k(10); ImageClassificationStage eval(eval_config); eval.SetAllLabels(model_labels); - if (eval.Init(&delegate_providers) != kTfLiteOk) return false; + if (eval.Init(&delegate_providers_) != kTfLiteOk) return absl::nullopt; const int step = image_labels.size() / 100; for (int i = 0; i < image_labels.size(); ++i) { if (step > 1 && i % step == 0) { TFLITE_LOG(INFO) << "Evaluated: " << i / step << "%"; } - eval.SetInputs(image_labels[i].image, image_labels[i].label); - if (eval.Run() != kTfLiteOk) return false; + if (eval.Run() != kTfLiteOk) return absl::nullopt; } const auto latest_metrics = eval.LatestMetrics(); - if (!output_file_path.empty()) { + OutputResult(latest_metrics); + return absl::make_optional(latest_metrics); +} + +void ImagenetClassification::OutputResult( + const EvaluationStageMetrics& latest_metrics) const { + if (!output_file_path_.empty()) { std::ofstream metrics_ofile; - metrics_ofile.open(output_file_path, std::ios::out); + metrics_ofile.open(output_file_path_, std::ios::out); metrics_ofile << latest_metrics.SerializeAsString(); metrics_ofile.close(); } + TFLITE_LOG(INFO) << "Num evaluation runs: " << latest_metrics.num_runs(); const auto& metrics = latest_metrics.process_metrics().image_classification_metrics(); @@ -105,103 +201,11 @@ bool EvaluateModel(const std::string& model_file_path, TFLITE_LOG(INFO) << "Top-" << i + 1 << " Accuracy: " << accuracy_metrics.topk_accuracies(i); } - - return true; } -int Main(int argc, char* argv[]) { - // Command Line Flags. - std::string model_file_path; - std::string ground_truth_images_path; - std::string ground_truth_labels_path; - std::string model_output_labels_path; - std::string blacklist_file_path; - std::string output_file_path; - std::string delegate; - int num_images = 0; - int num_interpreter_threads = 1; - std::vector flag_list = { - tflite::Flag::CreateFlag(kModelFileFlag, &model_file_path, - "Path to test tflite model file."), - tflite::Flag::CreateFlag( - kModelOutputLabelsFlag, &model_output_labels_path, - "Path to labels that correspond to output of model." - " E.g. in case of mobilenet, this is the path to label " - "file where each label is in the same order as the output" - " of the model."), - tflite::Flag::CreateFlag( - kGroundTruthImagesPathFlag, &ground_truth_images_path, - "Path to ground truth images. These will be evaluated in " - "alphabetical order of filename"), - tflite::Flag::CreateFlag( - kGroundTruthLabelsFlag, &ground_truth_labels_path, - "Path to ground truth labels, corresponding to alphabetical ordering " - "of ground truth images."), - tflite::Flag::CreateFlag( - kBlacklistFilePathFlag, &blacklist_file_path, - "Path to blacklist file (optional) where each line is a single " - "integer that is " - "equal to index number of blacklisted image."), - tflite::Flag::CreateFlag(kOutputFilePathFlag, &output_file_path, - "File to output metrics proto to."), - tflite::Flag::CreateFlag(kNumImagesFlag, &num_images, - "Number of examples to evaluate, pass 0 for all " - "examples. Default: 0"), - tflite::Flag::CreateFlag( - kInterpreterThreadsFlag, &num_interpreter_threads, - "Number of interpreter threads to use for inference."), - tflite::Flag::CreateFlag(kDelegateFlag, &delegate, - "Delegate to use for inference, if available. " - "Must be one of {'nnapi', 'gpu'}"), - }; - tflite::Flags::Parse(&argc, const_cast(argv), flag_list); - DelegateProviders delegate_providers; - delegate_providers.InitFromCmdlineArgs(&argc, const_cast(argv)); - - // Process images in filename-sorted order. - std::vector image_files, ground_truth_image_labels; - TF_LITE_ENSURE_STATUS(GetSortedFileNames( - StripTrailingSlashes(ground_truth_images_path), &image_files)); - if (!ReadFileLines(ground_truth_labels_path, &ground_truth_image_labels)) { - TFLITE_LOG(ERROR) << "Could not read ground truth labels file"; - return EXIT_FAILURE; - } - if (image_files.size() != ground_truth_image_labels.size()) { - TFLITE_LOG(ERROR) << "Number of images and ground truth labels is not same"; - return EXIT_FAILURE; - } - std::vector image_labels; - image_labels.reserve(image_files.size()); - for (int i = 0; i < image_files.size(); i++) { - image_labels.push_back({image_files[i], ground_truth_image_labels[i]}); - } - - // Filter out blacklisted/unwanted images. - TF_LITE_ENSURE_STATUS( - FilterBlackListedImages(blacklist_file_path, &image_labels)); - if (num_images > 0) { - image_labels = GetFirstN(image_labels, num_images); - } - - std::vector model_labels; - if (!ReadFileLines(model_output_labels_path, &model_labels)) { - TFLITE_LOG(ERROR) << "Could not read model output labels file"; - return EXIT_FAILURE; - } - - if (!EvaluateModel(model_file_path, image_labels, model_labels, delegate, - output_file_path, num_interpreter_threads, - delegate_providers)) { - TFLITE_LOG(ERROR) << "Could not evaluate model"; - return EXIT_FAILURE; - } - - return EXIT_SUCCESS; +std::unique_ptr CreateTaskExecutor(int* argc, char* argv[]) { + return std::unique_ptr(new ImagenetClassification(argc, argv)); } } // namespace evaluation } // namespace tflite - -int main(int argc, char* argv[]) { - return tflite::evaluation::Main(argc, argv); -} diff --git a/tensorflow/lite/tools/evaluation/tasks/inference_diff/BUILD b/tensorflow/lite/tools/evaluation/tasks/inference_diff/BUILD index e4d2c7ddd77..7c0eef7da37 100644 --- a/tensorflow/lite/tools/evaluation/tasks/inference_diff/BUILD +++ b/tensorflow/lite/tools/evaluation/tasks/inference_diff/BUILD @@ -1,4 +1,5 @@ -load("//tensorflow/lite:build_def.bzl", "tflite_copts", "tflite_linkopts") +load("//tensorflow/lite:build_def.bzl", "tflite_copts") +load("//tensorflow/lite/tools/evaluation/tasks:build_def.bzl", "task_linkopts") package( default_visibility = [ @@ -7,18 +8,11 @@ package( licenses = ["notice"], # Apache 2.0 ) -cc_binary( - name = "run_eval", +cc_library( + name = "run_eval_lib", srcs = ["run_eval.cc"], copts = tflite_copts(), - linkopts = tflite_linkopts() + select({ - "//tensorflow:android": [ - "-pie", # Android 5.0 and later supports only PIE - "-lm", # some builtin ops, e.g., tanh, need -lm - "-Wl,--rpath=/data/local/tmp/", # Hexagon delegate libraries should be in /data/local/tmp - ], - "//conditions:default": [], - }), + linkopts = task_linkopts(), deps = [ "//tensorflow/lite/c:common", "//tensorflow/lite/tools:command_line_flags", @@ -28,5 +22,17 @@ cc_binary( "//tensorflow/lite/tools/evaluation/proto:evaluation_config_cc_proto", "//tensorflow/lite/tools/evaluation/proto:evaluation_stages_cc_proto", "//tensorflow/lite/tools/evaluation/stages:inference_profiler_stage", + "//tensorflow/lite/tools/evaluation/tasks:task_executor", + "@com_google_absl//absl/types:optional", + ], +) + +cc_binary( + name = "run_eval", + copts = tflite_copts(), + linkopts = task_linkopts(), + deps = [ + ":run_eval_lib", + "//tensorflow/lite/tools/evaluation/tasks:task_executor_main", ], ) diff --git a/tensorflow/lite/tools/evaluation/tasks/inference_diff/run_eval.cc b/tensorflow/lite/tools/evaluation/tasks/inference_diff/run_eval.cc index 1dc89927760..814ebe3b3bf 100644 --- a/tensorflow/lite/tools/evaluation/tasks/inference_diff/run_eval.cc +++ b/tensorflow/lite/tools/evaluation/tasks/inference_diff/run_eval.cc @@ -16,12 +16,14 @@ limitations under the License. #include #include +#include "absl/types/optional.h" #include "tensorflow/lite/c/common.h" #include "tensorflow/lite/tools/command_line_flags.h" #include "tensorflow/lite/tools/evaluation/evaluation_delegate_provider.h" #include "tensorflow/lite/tools/evaluation/proto/evaluation_config.pb.h" #include "tensorflow/lite/tools/evaluation/proto/evaluation_stages.pb.h" #include "tensorflow/lite/tools/evaluation/stages/inference_profiler_stage.h" +#include "tensorflow/lite/tools/evaluation/tasks/task_executor.h" #include "tensorflow/lite/tools/logging.h" namespace tflite { @@ -33,43 +35,88 @@ constexpr char kNumRunsFlag[] = "num_runs"; constexpr char kInterpreterThreadsFlag[] = "num_interpreter_threads"; constexpr char kDelegateFlag[] = "delegate"; -bool EvaluateModel(const std::string& model_file_path, - const std::string& delegate, int num_runs, - const std::string& output_file_path, - int num_interpreter_threads, - const DelegateProviders& delegate_providers) { +class InferenceDiff : public TaskExecutor { + public: + InferenceDiff(int* argc, char* argv[]); + ~InferenceDiff() override {} + + // If the run is successful, the latest metrics will be returned. + absl::optional Run() final; + + private: + void OutputResult(const EvaluationStageMetrics& latest_metrics) const; + std::string model_file_path_; + std::string output_file_path_; + std::string delegate_; + int num_runs_; + int num_interpreter_threads_; + DelegateProviders delegate_providers_; +}; + +InferenceDiff::InferenceDiff(int* argc, char* argv[]) + : num_runs_(50), num_interpreter_threads_(1) { + // Command Line Flags. + std::vector flag_list = { + tflite::Flag::CreateFlag(kModelFileFlag, &model_file_path_, + "Path to test tflite model file."), + tflite::Flag::CreateFlag(kOutputFilePathFlag, &output_file_path_, + "File to output metrics proto to."), + tflite::Flag::CreateFlag(kNumRunsFlag, &num_runs_, + "Number of runs of test & reference inference " + "each. Default value: 50"), + tflite::Flag::CreateFlag( + kInterpreterThreadsFlag, &num_interpreter_threads_, + "Number of interpreter threads to use for test inference."), + tflite::Flag::CreateFlag( + kDelegateFlag, &delegate_, + "Delegate to use for test inference, if available. " + "Must be one of {'nnapi', 'gpu', 'hexagon', 'xnnpack'}"), + }; + tflite::Flags::Parse(argc, const_cast(argv), flag_list); + delegate_providers_.InitFromCmdlineArgs(argc, const_cast(argv)); +} + +absl::optional InferenceDiff::Run() { // Initialize evaluation stage. EvaluationStageConfig eval_config; eval_config.set_name("inference_profiling"); auto* inference_params = eval_config.mutable_specification()->mutable_tflite_inference_params(); - inference_params->set_model_file_path(model_file_path); - inference_params->set_num_threads(num_interpreter_threads); + inference_params->set_model_file_path(model_file_path_); + inference_params->set_num_threads(num_interpreter_threads_); // This ensures that latency measurement isn't hampered by the time spent in // generating random data. inference_params->set_invocations_per_run(3); - inference_params->set_delegate(ParseStringToDelegateType(delegate)); - if (!delegate.empty() && + inference_params->set_delegate(ParseStringToDelegateType(delegate_)); + if (!delegate_.empty() && inference_params->delegate() == TfliteInferenceParams::NONE) { - TFLITE_LOG(WARN) << "Unsupported TFLite delegate: " << delegate; - return false; + TFLITE_LOG(WARN) << "Unsupported TFLite delegate: " << delegate_; + return absl::nullopt; } + InferenceProfilerStage eval(eval_config); - if (eval.Init(&delegate_providers) != kTfLiteOk) return false; + if (eval.Init(&delegate_providers_) != kTfLiteOk) return absl::nullopt; // Run inference & check diff for specified number of runs. - for (int i = 0; i < num_runs; ++i) { - if (eval.Run() != kTfLiteOk) return false; + for (int i = 0; i < num_runs_; ++i) { + if (eval.Run() != kTfLiteOk) return absl::nullopt; } - // Output latency & diff metrics. const auto latest_metrics = eval.LatestMetrics(); - if (!output_file_path.empty()) { + OutputResult(latest_metrics); + return absl::make_optional(latest_metrics); +} + +void InferenceDiff::OutputResult( + const EvaluationStageMetrics& latest_metrics) const { + // Output latency & diff metrics. + if (!output_file_path_.empty()) { std::ofstream metrics_ofile; - metrics_ofile.open(output_file_path, std::ios::out); + metrics_ofile.open(output_file_path_, std::ios::out); metrics_ofile << latest_metrics.SerializeAsString(); metrics_ofile.close(); } + TFLITE_LOG(INFO) << "Num evaluation runs: " << latest_metrics.num_runs(); const auto& metrics = latest_metrics.process_metrics().inference_profiler_metrics(); @@ -88,48 +135,11 @@ bool EvaluateModel(const std::string& model_file_path, << "]: avg_error=" << error.avg_value() << ", std_dev=" << error.std_deviation(); } - return true; } -int Main(int argc, char* argv[]) { - // Command Line Flags. - std::string model_file_path; - std::string output_file_path; - std::string delegate; - int num_runs = 50; - int num_interpreter_threads = 1; - std::vector flag_list = { - tflite::Flag::CreateFlag(kModelFileFlag, &model_file_path, - "Path to test tflite model file."), - tflite::Flag::CreateFlag(kOutputFilePathFlag, &output_file_path, - "File to output metrics proto to."), - tflite::Flag::CreateFlag(kNumRunsFlag, &num_runs, - "Number of runs of test & reference inference " - "each. Default value: 50"), - tflite::Flag::CreateFlag( - kInterpreterThreadsFlag, &num_interpreter_threads, - "Number of interpreter threads to use for test inference."), - tflite::Flag::CreateFlag( - kDelegateFlag, &delegate, - "Delegate to use for test inference, if available. " - "Must be one of {'nnapi', 'gpu', 'hexagon'}"), - }; - tflite::Flags::Parse(&argc, const_cast(argv), flag_list); - - DelegateProviders delegate_providers; - delegate_providers.InitFromCmdlineArgs(&argc, const_cast(argv)); - if (!EvaluateModel(model_file_path, delegate, num_runs, output_file_path, - num_interpreter_threads, delegate_providers)) { - TFLITE_LOG(ERROR) << "Could not evaluate model!"; - return EXIT_FAILURE; - } - - return EXIT_SUCCESS; +std::unique_ptr CreateTaskExecutor(int* argc, char* argv[]) { + return std::unique_ptr(new InferenceDiff(argc, argv)); } } // namespace evaluation } // namespace tflite - -int main(int argc, char* argv[]) { - return tflite::evaluation::Main(argc, argv); -} diff --git a/tensorflow/lite/tools/evaluation/tasks/task_executor.h b/tensorflow/lite/tools/evaluation/tasks/task_executor.h new file mode 100644 index 00000000000..b50e7d6d03f --- /dev/null +++ b/tensorflow/lite/tools/evaluation/tasks/task_executor.h @@ -0,0 +1,38 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_LITE_TOOLS_EVALUATION_TASKS_TASK_EXECUTOR_H_ +#define TENSORFLOW_LITE_TOOLS_EVALUATION_TASKS_TASK_EXECUTOR_H_ + +#include "absl/types/optional.h" +#include "tensorflow/lite/tools/evaluation/proto/evaluation_config.pb.h" + +namespace tflite { +namespace evaluation { +// A common task execution API to avoid boilerpolate code in defining the main +// function. +class TaskExecutor { + public: + virtual ~TaskExecutor() {} + // If the run is successful, the latest metrics will be returned. + virtual absl::optional Run() = 0; +}; + +// Just a declaration. In order to avoid the boilerpolate main-function code, +// every evaluation task should define this function. +std::unique_ptr CreateTaskExecutor(int* argc, char* argv[]); +} // namespace evaluation +} // namespace tflite + +#endif // TENSORFLOW_LITE_TOOLS_EVALUATION_TASKS_TASK_EXECUTOR_H_ diff --git a/tensorflow/lite/tools/evaluation/tasks/task_executor_main.cc b/tensorflow/lite/tools/evaluation/tasks/task_executor_main.cc new file mode 100644 index 00000000000..6ef1a6724b7 --- /dev/null +++ b/tensorflow/lite/tools/evaluation/tasks/task_executor_main.cc @@ -0,0 +1,32 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "absl/types/optional.h" +#include "tensorflow/lite/tools/evaluation/tasks/task_executor.h" +#include "tensorflow/lite/tools/logging.h" + +// This could serve as the main function for all eval tools. +int main(int argc, char* argv[]) { + auto task_executor = tflite::evaluation::CreateTaskExecutor(&argc, argv); + if (task_executor == nullptr) { + TFLITE_LOG(ERROR) << "Could not create the task evaluation!"; + return EXIT_FAILURE; + } + const auto metrics = task_executor->Run(); + if (!metrics.has_value()) { + TFLITE_LOG(ERROR) << "Could not run the task evaluation!"; + return EXIT_FAILURE; + } + return EXIT_SUCCESS; +}