Refactor the inference_diff and imagenet-accuracy evaluation tasks into utility libraries, and create a common main stub function for these evaluation tasks.

PiperOrigin-RevId: 307327740
Change-Id: Id1c081fa3ad1796c6ff60441f9189a985148fb09
Chao Mei authored on 2020-04-19 20:41:28 -07:00; committed by TensorFlower Gardener
parent bbaaccaab2
commit a8c7cc2079
8 changed files with 333 additions and 193 deletions
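The same pattern is applied to both tasks below: each run_eval.cc now defines a TaskExecutor subclass plus a CreateTaskExecutor() factory, the task's BUILD target becomes a cc_library named run_eval_lib, and the run_eval binary is reduced to linking that library against the shared main stub. The new task_executor.h and task_executor_main.cc files appear at the end of the diff.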

tensorflow/lite/tools/evaluation/tasks/BUILD (new file)

@@ -0,0 +1,32 @@
load("//tensorflow/lite:build_def.bzl", "tflite_copts")
load("//tensorflow/lite/tools/evaluation/tasks:build_def.bzl", "task_linkopts")
package(
default_visibility = [
"//visibility:public",
],
licenses = ["notice"], # Apache 2.0
)
cc_library(
name = "task_executor",
hdrs = ["task_executor.h"],
copts = tflite_copts(),
linkopts = task_linkopts(),
deps = [
"//tensorflow/lite/tools/evaluation/proto:evaluation_config_cc_proto",
"@com_google_absl//absl/types:optional",
],
)
cc_library(
name = "task_executor_main",
srcs = ["task_executor_main.cc"],
copts = tflite_copts(),
linkopts = task_linkopts(),
deps = [
":task_executor",
"//tensorflow/lite/tools:logging",
"@com_google_absl//absl/types:optional",
],
)

tensorflow/lite/tools/evaluation/tasks/build_def.bzl (new file)

@@ -0,0 +1,14 @@
"""Common BUILD-related definitions across different tasks"""
load("//tensorflow/lite:build_def.bzl", "tflite_linkopts")
def task_linkopts():
return tflite_linkopts() + select({
"//tensorflow:android": [
"-pie", # Android 5.0 and later supports only PIE
"-lm", # some builtin ops, e.g., tanh, need -lm
# Hexagon delegate libraries should be in /data/local/tmp
"-Wl,--rpath=/data/local/tmp/",
],
"//conditions:default": [],
})
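task_linkopts() centralizes the Android link flags that were previously copy-pasted into each task's BUILD file (see the common_linkopts deletions below). As a sketch, a hypothetical new task would consume it like this; the target and file names here are invented:

load("//tensorflow/lite:build_def.bzl", "tflite_copts")
load("//tensorflow/lite/tools/evaluation/tasks:build_def.bzl", "task_linkopts")

cc_library(
    name = "my_task_lib",  # hypothetical task library
    srcs = ["my_task.cc"],
    copts = tflite_copts(),
    linkopts = task_linkopts(),  # shared -pie/-lm/rpath flags on Android
)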

tensorflow/lite/tools/evaluation/tasks/imagenet_image_classification/BUILD

@@ -1,4 +1,5 @@
load("//tensorflow/lite:build_def.bzl", "tflite_copts", "tflite_linkopts")
load("//tensorflow/lite:build_def.bzl", "tflite_copts")
load("//tensorflow/lite/tools/evaluation/tasks:build_def.bzl", "task_linkopts")
package(
default_visibility = [
@@ -7,20 +8,11 @@ package(
licenses = ["notice"], # Apache 2.0
)
common_linkopts = tflite_linkopts() + select({
"//tensorflow:android": [
"-pie", # Android 5.0 and later supports only PIE
"-lm", # some builtin ops, e.g., tanh, need -lm
"-Wl,--rpath=/data/local/tmp/", # Hexagon delegate libraries should be in /data/local/tmp
],
"//conditions:default": [],
})
cc_binary(
name = "run_eval",
cc_library(
name = "run_eval_lib",
srcs = ["run_eval.cc"],
copts = tflite_copts(),
linkopts = common_linkopts,
linkopts = task_linkopts(),
deps = [
"//tensorflow/lite/c:common",
"//tensorflow/lite/tools:command_line_flags",
@@ -31,5 +23,17 @@ cc_binary(
"//tensorflow/lite/tools/evaluation/proto:evaluation_config_cc_proto",
"//tensorflow/lite/tools/evaluation/proto:evaluation_stages_cc_proto",
"//tensorflow/lite/tools/evaluation/stages:image_classification_stage",
"//tensorflow/lite/tools/evaluation/tasks:task_executor",
"@com_google_absl//absl/types:optional",
],
)
cc_binary(
name = "run_eval",
copts = tflite_copts(),
linkopts = task_linkopts(),
deps = [
":run_eval_lib",
"//tensorflow/lite/tools/evaluation/tasks:task_executor_main",
],
)
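Note the resulting split: run_eval_lib carries all of the task's code, while the run_eval binary declares no srcs of its own; its main() comes from the task_executor_main dependency. The inference_diff BUILD file below receives the identical treatment.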

tensorflow/lite/tools/evaluation/tasks/imagenet_image_classification/run_eval.cc

@@ -17,12 +17,14 @@ limitations under the License.
#include <string>
#include <vector>
#include "absl/types/optional.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/tools/command_line_flags.h"
#include "tensorflow/lite/tools/evaluation/evaluation_delegate_provider.h"
#include "tensorflow/lite/tools/evaluation/proto/evaluation_config.pb.h"
#include "tensorflow/lite/tools/evaluation/proto/evaluation_stages.pb.h"
#include "tensorflow/lite/tools/evaluation/stages/image_classification_stage.h"
#include "tensorflow/lite/tools/evaluation/tasks/task_executor.h"
#include "tensorflow/lite/tools/evaluation/utils.h"
#include "tensorflow/lite/tools/logging.h"
@@ -46,49 +48,143 @@ std::vector<T> GetFirstN(const std::vector<T>& v, int n) {
return result;
}
bool EvaluateModel(const std::string& model_file_path,
const std::vector<ImageLabel>& image_labels,
const std::vector<std::string>& model_labels,
std::string delegate, std::string output_file_path,
int num_interpreter_threads,
const DelegateProviders& delegate_providers) {
class ImagenetClassification : public TaskExecutor {
public:
ImagenetClassification(int* argc, char* argv[]);
~ImagenetClassification() override {}
// If the run is successful, the latest metrics will be returned.
absl::optional<EvaluationStageMetrics> Run() final;
private:
void OutputResult(const EvaluationStageMetrics& latest_metrics) const;
std::string model_file_path_;
std::string ground_truth_images_path_;
std::string ground_truth_labels_path_;
std::string model_output_labels_path_;
std::string blacklist_file_path_;
std::string output_file_path_;
std::string delegate_;
int num_images_;
int num_interpreter_threads_;
DelegateProviders delegate_providers_;
};
ImagenetClassification::ImagenetClassification(int* argc, char* argv[])
: num_images_(0), num_interpreter_threads_(1) {
std::vector<tflite::Flag> flag_list = {
tflite::Flag::CreateFlag(kModelFileFlag, &model_file_path_,
"Path to test tflite model file."),
tflite::Flag::CreateFlag(
kModelOutputLabelsFlag, &model_output_labels_path_,
"Path to labels that correspond to output of model."
" E.g. in case of mobilenet, this is the path to label "
"file where each label is in the same order as the output"
" of the model."),
tflite::Flag::CreateFlag(
kGroundTruthImagesPathFlag, &ground_truth_images_path_,
"Path to ground truth images. These will be evaluated in "
"alphabetical order of filename"),
tflite::Flag::CreateFlag(
kGroundTruthLabelsFlag, &ground_truth_labels_path_,
"Path to ground truth labels, corresponding to alphabetical ordering "
"of ground truth images."),
tflite::Flag::CreateFlag(
kBlacklistFilePathFlag, &blacklist_file_path_,
"Path to blacklist file (optional) where each line is a single "
"integer that is "
"equal to index number of blacklisted image."),
tflite::Flag::CreateFlag(kOutputFilePathFlag, &output_file_path_,
"File to output metrics proto to."),
tflite::Flag::CreateFlag(kNumImagesFlag, &num_images_,
"Number of examples to evaluate, pass 0 for all "
"examples. Default: 0"),
tflite::Flag::CreateFlag(
kInterpreterThreadsFlag, &num_interpreter_threads_,
"Number of interpreter threads to use for inference."),
tflite::Flag::CreateFlag(
kDelegateFlag, &delegate_,
"Delegate to use for inference, if available. "
"Must be one of {'nnapi', 'gpu', 'hexagon', 'xnnpack'}"),
};
tflite::Flags::Parse(argc, const_cast<const char**>(argv), flag_list);
delegate_providers_.InitFromCmdlineArgs(argc, const_cast<const char**>(argv));
}
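Since argc and argv are passed by pointer, tflite::Flags::Parse first consumes the task-specific flags, and InitFromCmdlineArgs then picks up any remaining delegate-provider flags from the same command line.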
absl::optional<EvaluationStageMetrics> ImagenetClassification::Run() {
// Process images in filename-sorted order.
std::vector<std::string> image_files, ground_truth_image_labels;
if (GetSortedFileNames(StripTrailingSlashes(ground_truth_images_path_),
&image_files) != kTfLiteOk) {
return absl::nullopt;
}
if (!ReadFileLines(ground_truth_labels_path_, &ground_truth_image_labels)) {
TFLITE_LOG(ERROR) << "Could not read ground truth labels file";
return absl::nullopt;
}
if (image_files.size() != ground_truth_image_labels.size()) {
TFLITE_LOG(ERROR) << "Number of images and ground truth labels is not same";
return absl::nullopt;
}
std::vector<ImageLabel> image_labels;
image_labels.reserve(image_files.size());
for (int i = 0; i < image_files.size(); i++) {
image_labels.push_back({image_files[i], ground_truth_image_labels[i]});
}
// Filter out blacklisted/unwanted images.
if (FilterBlackListedImages(blacklist_file_path_, &image_labels) !=
kTfLiteOk) {
return absl::nullopt;
}
if (num_images_ > 0) {
image_labels = GetFirstN(image_labels, num_images_);
}
std::vector<std::string> model_labels;
if (!ReadFileLines(model_output_labels_path_, &model_labels)) {
TFLITE_LOG(ERROR) << "Could not read model output labels file";
return absl::nullopt;
}
EvaluationStageConfig eval_config;
eval_config.set_name("image_classification");
auto* classification_params = eval_config.mutable_specification()
->mutable_image_classification_params();
auto* inference_params = classification_params->mutable_inference_params();
inference_params->set_model_file_path(model_file_path);
inference_params->set_num_threads(num_interpreter_threads);
inference_params->set_delegate(ParseStringToDelegateType(delegate));
if (!delegate.empty() &&
inference_params->delegate() == TfliteInferenceParams::NONE) {
TFLITE_LOG(WARN) << "Unsupported TFLite delegate: " << delegate;
return false;
}
inference_params->set_model_file_path(model_file_path_);
inference_params->set_num_threads(num_interpreter_threads_);
inference_params->set_delegate(ParseStringToDelegateType(delegate_));
classification_params->mutable_topk_accuracy_eval_params()->set_k(10);
ImageClassificationStage eval(eval_config);
eval.SetAllLabels(model_labels);
if (eval.Init(&delegate_providers) != kTfLiteOk) return false;
if (eval.Init(&delegate_providers_) != kTfLiteOk) return absl::nullopt;
const int step = image_labels.size() / 100;
for (int i = 0; i < image_labels.size(); ++i) {
if (step > 1 && i % step == 0) {
TFLITE_LOG(INFO) << "Evaluated: " << i / step << "%";
}
eval.SetInputs(image_labels[i].image, image_labels[i].label);
if (eval.Run() != kTfLiteOk) return false;
if (eval.Run() != kTfLiteOk) return absl::nullopt;
}
const auto latest_metrics = eval.LatestMetrics();
if (!output_file_path.empty()) {
OutputResult(latest_metrics);
return absl::make_optional(latest_metrics);
}
void ImagenetClassification::OutputResult(
const EvaluationStageMetrics& latest_metrics) const {
if (!output_file_path_.empty()) {
std::ofstream metrics_ofile;
metrics_ofile.open(output_file_path, std::ios::out);
metrics_ofile.open(output_file_path_, std::ios::out);
metrics_ofile << latest_metrics.SerializeAsString();
metrics_ofile.close();
}
TFLITE_LOG(INFO) << "Num evaluation runs: " << latest_metrics.num_runs();
const auto& metrics =
latest_metrics.process_metrics().image_classification_metrics();
@@ -105,103 +201,11 @@ bool EvaluateModel(const std::string& model_file_path,
TFLITE_LOG(INFO) << "Top-" << i + 1
<< " Accuracy: " << accuracy_metrics.topk_accuracies(i);
}
return true;
}
int Main(int argc, char* argv[]) {
// Command Line Flags.
std::string model_file_path;
std::string ground_truth_images_path;
std::string ground_truth_labels_path;
std::string model_output_labels_path;
std::string blacklist_file_path;
std::string output_file_path;
std::string delegate;
int num_images = 0;
int num_interpreter_threads = 1;
std::vector<tflite::Flag> flag_list = {
tflite::Flag::CreateFlag(kModelFileFlag, &model_file_path,
"Path to test tflite model file."),
tflite::Flag::CreateFlag(
kModelOutputLabelsFlag, &model_output_labels_path,
"Path to labels that correspond to output of model."
" E.g. in case of mobilenet, this is the path to label "
"file where each label is in the same order as the output"
" of the model."),
tflite::Flag::CreateFlag(
kGroundTruthImagesPathFlag, &ground_truth_images_path,
"Path to ground truth images. These will be evaluated in "
"alphabetical order of filename"),
tflite::Flag::CreateFlag(
kGroundTruthLabelsFlag, &ground_truth_labels_path,
"Path to ground truth labels, corresponding to alphabetical ordering "
"of ground truth images."),
tflite::Flag::CreateFlag(
kBlacklistFilePathFlag, &blacklist_file_path,
"Path to blacklist file (optional) where each line is a single "
"integer that is "
"equal to index number of blacklisted image."),
tflite::Flag::CreateFlag(kOutputFilePathFlag, &output_file_path,
"File to output metrics proto to."),
tflite::Flag::CreateFlag(kNumImagesFlag, &num_images,
"Number of examples to evaluate, pass 0 for all "
"examples. Default: 0"),
tflite::Flag::CreateFlag(
kInterpreterThreadsFlag, &num_interpreter_threads,
"Number of interpreter threads to use for inference."),
tflite::Flag::CreateFlag(kDelegateFlag, &delegate,
"Delegate to use for inference, if available. "
"Must be one of {'nnapi', 'gpu'}"),
};
tflite::Flags::Parse(&argc, const_cast<const char**>(argv), flag_list);
DelegateProviders delegate_providers;
delegate_providers.InitFromCmdlineArgs(&argc, const_cast<const char**>(argv));
// Process images in filename-sorted order.
std::vector<std::string> image_files, ground_truth_image_labels;
TF_LITE_ENSURE_STATUS(GetSortedFileNames(
StripTrailingSlashes(ground_truth_images_path), &image_files));
if (!ReadFileLines(ground_truth_labels_path, &ground_truth_image_labels)) {
TFLITE_LOG(ERROR) << "Could not read ground truth labels file";
return EXIT_FAILURE;
}
if (image_files.size() != ground_truth_image_labels.size()) {
TFLITE_LOG(ERROR) << "Number of images and ground truth labels is not same";
return EXIT_FAILURE;
}
std::vector<ImageLabel> image_labels;
image_labels.reserve(image_files.size());
for (int i = 0; i < image_files.size(); i++) {
image_labels.push_back({image_files[i], ground_truth_image_labels[i]});
}
// Filter out blacklisted/unwanted images.
TF_LITE_ENSURE_STATUS(
FilterBlackListedImages(blacklist_file_path, &image_labels));
if (num_images > 0) {
image_labels = GetFirstN(image_labels, num_images);
}
std::vector<std::string> model_labels;
if (!ReadFileLines(model_output_labels_path, &model_labels)) {
TFLITE_LOG(ERROR) << "Could not read model output labels file";
return EXIT_FAILURE;
}
if (!EvaluateModel(model_file_path, image_labels, model_labels, delegate,
output_file_path, num_interpreter_threads,
delegate_providers)) {
TFLITE_LOG(ERROR) << "Could not evaluate model";
return EXIT_FAILURE;
}
return EXIT_SUCCESS;
std::unique_ptr<TaskExecutor> CreateTaskExecutor(int* argc, char* argv[]) {
return std::unique_ptr<TaskExecutor>(new ImagenetClassification(argc, argv));
}
} // namespace evaluation
} // namespace tflite
int main(int argc, char* argv[]) {
return tflite::evaluation::Main(argc, argv);
}

tensorflow/lite/tools/evaluation/tasks/inference_diff/BUILD

@@ -1,4 +1,5 @@
load("//tensorflow/lite:build_def.bzl", "tflite_copts", "tflite_linkopts")
load("//tensorflow/lite:build_def.bzl", "tflite_copts")
load("//tensorflow/lite/tools/evaluation/tasks:build_def.bzl", "task_linkopts")
package(
default_visibility = [
@@ -7,18 +8,11 @@ package(
licenses = ["notice"], # Apache 2.0
)
cc_binary(
name = "run_eval",
cc_library(
name = "run_eval_lib",
srcs = ["run_eval.cc"],
copts = tflite_copts(),
linkopts = tflite_linkopts() + select({
"//tensorflow:android": [
"-pie", # Android 5.0 and later supports only PIE
"-lm", # some builtin ops, e.g., tanh, need -lm
"-Wl,--rpath=/data/local/tmp/", # Hexagon delegate libraries should be in /data/local/tmp
],
"//conditions:default": [],
}),
linkopts = task_linkopts(),
deps = [
"//tensorflow/lite/c:common",
"//tensorflow/lite/tools:command_line_flags",
@@ -28,5 +22,17 @@ cc_binary(
"//tensorflow/lite/tools/evaluation/proto:evaluation_config_cc_proto",
"//tensorflow/lite/tools/evaluation/proto:evaluation_stages_cc_proto",
"//tensorflow/lite/tools/evaluation/stages:inference_profiler_stage",
"//tensorflow/lite/tools/evaluation/tasks:task_executor",
"@com_google_absl//absl/types:optional",
],
)
cc_binary(
name = "run_eval",
copts = tflite_copts(),
linkopts = task_linkopts(),
deps = [
":run_eval_lib",
"//tensorflow/lite/tools/evaluation/tasks:task_executor_main",
],
)

tensorflow/lite/tools/evaluation/tasks/inference_diff/run_eval.cc

@@ -16,12 +16,14 @@ limitations under the License.
#include <string>
#include <vector>
#include "absl/types/optional.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/tools/command_line_flags.h"
#include "tensorflow/lite/tools/evaluation/evaluation_delegate_provider.h"
#include "tensorflow/lite/tools/evaluation/proto/evaluation_config.pb.h"
#include "tensorflow/lite/tools/evaluation/proto/evaluation_stages.pb.h"
#include "tensorflow/lite/tools/evaluation/stages/inference_profiler_stage.h"
#include "tensorflow/lite/tools/evaluation/tasks/task_executor.h"
#include "tensorflow/lite/tools/logging.h"
namespace tflite {
@@ -33,43 +35,88 @@ constexpr char kNumRunsFlag[] = "num_runs";
constexpr char kInterpreterThreadsFlag[] = "num_interpreter_threads";
constexpr char kDelegateFlag[] = "delegate";
bool EvaluateModel(const std::string& model_file_path,
const std::string& delegate, int num_runs,
const std::string& output_file_path,
int num_interpreter_threads,
const DelegateProviders& delegate_providers) {
class InferenceDiff : public TaskExecutor {
public:
InferenceDiff(int* argc, char* argv[]);
~InferenceDiff() override {}
// If the run is successful, the latest metrics will be returned.
absl::optional<EvaluationStageMetrics> Run() final;
private:
void OutputResult(const EvaluationStageMetrics& latest_metrics) const;
std::string model_file_path_;
std::string output_file_path_;
std::string delegate_;
int num_runs_;
int num_interpreter_threads_;
DelegateProviders delegate_providers_;
};
InferenceDiff::InferenceDiff(int* argc, char* argv[])
: num_runs_(50), num_interpreter_threads_(1) {
// Command Line Flags.
std::vector<tflite::Flag> flag_list = {
tflite::Flag::CreateFlag(kModelFileFlag, &model_file_path_,
"Path to test tflite model file."),
tflite::Flag::CreateFlag(kOutputFilePathFlag, &output_file_path_,
"File to output metrics proto to."),
tflite::Flag::CreateFlag(kNumRunsFlag, &num_runs_,
"Number of runs of test & reference inference "
"each. Default value: 50"),
tflite::Flag::CreateFlag(
kInterpreterThreadsFlag, &num_interpreter_threads_,
"Number of interpreter threads to use for test inference."),
tflite::Flag::CreateFlag(
kDelegateFlag, &delegate_,
"Delegate to use for test inference, if available. "
"Must be one of {'nnapi', 'gpu', 'hexagon', 'xnnpack'}"),
};
tflite::Flags::Parse(argc, const_cast<const char**>(argv), flag_list);
delegate_providers_.InitFromCmdlineArgs(argc, const_cast<const char**>(argv));
}
absl::optional<EvaluationStageMetrics> InferenceDiff::Run() {
// Initialize evaluation stage.
EvaluationStageConfig eval_config;
eval_config.set_name("inference_profiling");
auto* inference_params =
eval_config.mutable_specification()->mutable_tflite_inference_params();
inference_params->set_model_file_path(model_file_path);
inference_params->set_num_threads(num_interpreter_threads);
inference_params->set_model_file_path(model_file_path_);
inference_params->set_num_threads(num_interpreter_threads_);
// This ensures that latency measurement isn't hampered by the time spent in
// generating random data.
inference_params->set_invocations_per_run(3);
inference_params->set_delegate(ParseStringToDelegateType(delegate));
if (!delegate.empty() &&
inference_params->set_delegate(ParseStringToDelegateType(delegate_));
if (!delegate_.empty() &&
inference_params->delegate() == TfliteInferenceParams::NONE) {
TFLITE_LOG(WARN) << "Unsupported TFLite delegate: " << delegate;
return false;
TFLITE_LOG(WARN) << "Unsupported TFLite delegate: " << delegate_;
return absl::nullopt;
}
InferenceProfilerStage eval(eval_config);
if (eval.Init(&delegate_providers) != kTfLiteOk) return false;
if (eval.Init(&delegate_providers_) != kTfLiteOk) return absl::nullopt;
// Run inference & check diff for specified number of runs.
for (int i = 0; i < num_runs; ++i) {
if (eval.Run() != kTfLiteOk) return false;
for (int i = 0; i < num_runs_; ++i) {
if (eval.Run() != kTfLiteOk) return absl::nullopt;
}
// Output latency & diff metrics.
const auto latest_metrics = eval.LatestMetrics();
if (!output_file_path.empty()) {
OutputResult(latest_metrics);
return absl::make_optional(latest_metrics);
}
void InferenceDiff::OutputResult(
const EvaluationStageMetrics& latest_metrics) const {
// Output latency & diff metrics.
if (!output_file_path_.empty()) {
std::ofstream metrics_ofile;
metrics_ofile.open(output_file_path, std::ios::out);
metrics_ofile.open(output_file_path_, std::ios::out);
metrics_ofile << latest_metrics.SerializeAsString();
metrics_ofile.close();
}
TFLITE_LOG(INFO) << "Num evaluation runs: " << latest_metrics.num_runs();
const auto& metrics =
latest_metrics.process_metrics().inference_profiler_metrics();
@@ -88,48 +135,11 @@ bool EvaluateModel(const std::string& model_file_path,
<< "]: avg_error=" << error.avg_value()
<< ", std_dev=" << error.std_deviation();
}
return true;
}
int Main(int argc, char* argv[]) {
// Command Line Flags.
std::string model_file_path;
std::string output_file_path;
std::string delegate;
int num_runs = 50;
int num_interpreter_threads = 1;
std::vector<tflite::Flag> flag_list = {
tflite::Flag::CreateFlag(kModelFileFlag, &model_file_path,
"Path to test tflite model file."),
tflite::Flag::CreateFlag(kOutputFilePathFlag, &output_file_path,
"File to output metrics proto to."),
tflite::Flag::CreateFlag(kNumRunsFlag, &num_runs,
"Number of runs of test & reference inference "
"each. Default value: 50"),
tflite::Flag::CreateFlag(
kInterpreterThreadsFlag, &num_interpreter_threads,
"Number of interpreter threads to use for test inference."),
tflite::Flag::CreateFlag(
kDelegateFlag, &delegate,
"Delegate to use for test inference, if available. "
"Must be one of {'nnapi', 'gpu', 'hexagon'}"),
};
tflite::Flags::Parse(&argc, const_cast<const char**>(argv), flag_list);
DelegateProviders delegate_providers;
delegate_providers.InitFromCmdlineArgs(&argc, const_cast<const char**>(argv));
if (!EvaluateModel(model_file_path, delegate, num_runs, output_file_path,
num_interpreter_threads, delegate_providers)) {
TFLITE_LOG(ERROR) << "Could not evaluate model!";
return EXIT_FAILURE;
}
return EXIT_SUCCESS;
std::unique_ptr<TaskExecutor> CreateTaskExecutor(int* argc, char* argv[]) {
return std::unique_ptr<TaskExecutor>(new InferenceDiff(argc, argv));
}
} // namespace evaluation
} // namespace tflite
int main(int argc, char* argv[]) {
return tflite::evaluation::Main(argc, argv);
}

tensorflow/lite/tools/evaluation/tasks/task_executor.h (new file)

@@ -0,0 +1,38 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_TOOLS_EVALUATION_TASKS_TASK_EXECUTOR_H_
#define TENSORFLOW_LITE_TOOLS_EVALUATION_TASKS_TASK_EXECUTOR_H_
#include "absl/types/optional.h"
#include "tensorflow/lite/tools/evaluation/proto/evaluation_config.pb.h"
namespace tflite {
namespace evaluation {
// A common task execution API to avoid boilerplate code in defining the main
// function.
class TaskExecutor {
public:
virtual ~TaskExecutor() {}
// If the run is successful, the latest metrics will be returned.
virtual absl::optional<EvaluationStageMetrics> Run() = 0;
};
// Just a declaration. In order to avoid the boilerplate main-function code,
// every evaluation task should define this function.
std::unique_ptr<TaskExecutor> CreateTaskExecutor(int* argc, char* argv[]);
} // namespace evaluation
} // namespace tflite
#endif // TENSORFLOW_LITE_TOOLS_EVALUATION_TASKS_TASK_EXECUTOR_H_
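As a rough template for future tasks, a new tool would subclass TaskExecutor and define the factory. The sketch below is hypothetical (MyTask and its contents are invented names) and simply mirrors the two real tasks in this commit:

#include <memory>

#include "absl/types/optional.h"
#include "tensorflow/lite/tools/evaluation/proto/evaluation_config.pb.h"
#include "tensorflow/lite/tools/evaluation/tasks/task_executor.h"

namespace tflite {
namespace evaluation {

// Hypothetical task: parses its flags in the constructor, runs in Run().
class MyTask : public TaskExecutor {
 public:
  MyTask(int* argc, char* argv[]) {
    // Parse task-specific flags here, then forward the rest, e.g. to
    // DelegateProviders::InitFromCmdlineArgs, as the real tasks do.
  }
  ~MyTask() override {}

  absl::optional<EvaluationStageMetrics> Run() final {
    EvaluationStageMetrics metrics;
    // ... populate metrics from an evaluation stage; return absl::nullopt
    // on failure so the shared main stub exits with EXIT_FAILURE.
    return absl::make_optional(metrics);
  }
};

// Linking this with :task_executor_main yields a complete binary.
std::unique_ptr<TaskExecutor> CreateTaskExecutor(int* argc, char* argv[]) {
  return std::unique_ptr<TaskExecutor>(new MyTask(argc, argv));
}

}  // namespace evaluation
}  // namespace tflite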

tensorflow/lite/tools/evaluation/tasks/task_executor_main.cc (new file)

@@ -0,0 +1,32 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "absl/types/optional.h"
#include "tensorflow/lite/tools/evaluation/tasks/task_executor.h"
#include "tensorflow/lite/tools/logging.h"
// This could serve as the main function for all eval tools.
int main(int argc, char* argv[]) {
auto task_executor = tflite::evaluation::CreateTaskExecutor(&argc, argv);
if (task_executor == nullptr) {
TFLITE_LOG(ERROR) << "Could not create the task evaluation!";
return EXIT_FAILURE;
}
const auto metrics = task_executor->Run();
if (!metrics.has_value()) {
TFLITE_LOG(ERROR) << "Could not run the task evaluation!";
return EXIT_FAILURE;
}
return EXIT_SUCCESS;
}
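With the stub linked in, each task still builds to the same binary target as before, e.g. bazel build -c opt //tensorflow/lite/tools/evaluation/tasks/inference_diff:run_eval; the behavioral change is that a task now signals failure by returning absl::nullopt from Run() instead of a bool, and the stub maps that to EXIT_FAILURE.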