Refactor the inference_diff and imagenet-accuracy evaluation task to become a utility library, and create a common main stub function for these evaluation tasks.
PiperOrigin-RevId: 307327740 Change-Id: Id1c081fa3ad1796c6ff60441f9189a985148fb09
This commit is contained in:
parent
bbaaccaab2
commit
a8c7cc2079
32
tensorflow/lite/tools/evaluation/tasks/BUILD
Normal file
32
tensorflow/lite/tools/evaluation/tasks/BUILD
Normal file
@ -0,0 +1,32 @@
|
||||
load("//tensorflow/lite:build_def.bzl", "tflite_copts")
|
||||
load("//tensorflow/lite/tools/evaluation/tasks:build_def.bzl", "task_linkopts")
|
||||
|
||||
package(
|
||||
default_visibility = [
|
||||
"//visibility:public",
|
||||
],
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "task_executor",
|
||||
hdrs = ["task_executor.h"],
|
||||
copts = tflite_copts(),
|
||||
linkopts = task_linkopts(),
|
||||
deps = [
|
||||
"//tensorflow/lite/tools/evaluation/proto:evaluation_config_cc_proto",
|
||||
"@com_google_absl//absl/types:optional",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "task_executor_main",
|
||||
srcs = ["task_executor_main.cc"],
|
||||
copts = tflite_copts(),
|
||||
linkopts = task_linkopts(),
|
||||
deps = [
|
||||
":task_executor",
|
||||
"//tensorflow/lite/tools:logging",
|
||||
"@com_google_absl//absl/types:optional",
|
||||
],
|
||||
)
|
14
tensorflow/lite/tools/evaluation/tasks/build_def.bzl
Normal file
14
tensorflow/lite/tools/evaluation/tasks/build_def.bzl
Normal file
@ -0,0 +1,14 @@
|
||||
"""Common BUILD-related definitions across different tasks"""
|
||||
|
||||
load("//tensorflow/lite:build_def.bzl", "tflite_linkopts")
|
||||
|
||||
def task_linkopts():
|
||||
return tflite_linkopts() + select({
|
||||
"//tensorflow:android": [
|
||||
"-pie", # Android 5.0 and later supports only PIE
|
||||
"-lm", # some builtin ops, e.g., tanh, need -lm
|
||||
# Hexagon delegate libraries should be in /data/local/tmp
|
||||
"-Wl,--rpath=/data/local/tmp/",
|
||||
],
|
||||
"//conditions:default": [],
|
||||
})
|
@ -1,4 +1,5 @@
|
||||
load("//tensorflow/lite:build_def.bzl", "tflite_copts", "tflite_linkopts")
|
||||
load("//tensorflow/lite:build_def.bzl", "tflite_copts")
|
||||
load("//tensorflow/lite/tools/evaluation/tasks:build_def.bzl", "task_linkopts")
|
||||
|
||||
package(
|
||||
default_visibility = [
|
||||
@ -7,20 +8,11 @@ package(
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
)
|
||||
|
||||
common_linkopts = tflite_linkopts() + select({
|
||||
"//tensorflow:android": [
|
||||
"-pie", # Android 5.0 and later supports only PIE
|
||||
"-lm", # some builtin ops, e.g., tanh, need -lm
|
||||
"-Wl,--rpath=/data/local/tmp/", # Hexagon delegate libraries should be in /data/local/tmp
|
||||
],
|
||||
"//conditions:default": [],
|
||||
})
|
||||
|
||||
cc_binary(
|
||||
name = "run_eval",
|
||||
cc_library(
|
||||
name = "run_eval_lib",
|
||||
srcs = ["run_eval.cc"],
|
||||
copts = tflite_copts(),
|
||||
linkopts = common_linkopts,
|
||||
linkopts = task_linkopts(),
|
||||
deps = [
|
||||
"//tensorflow/lite/c:common",
|
||||
"//tensorflow/lite/tools:command_line_flags",
|
||||
@ -31,5 +23,17 @@ cc_binary(
|
||||
"//tensorflow/lite/tools/evaluation/proto:evaluation_config_cc_proto",
|
||||
"//tensorflow/lite/tools/evaluation/proto:evaluation_stages_cc_proto",
|
||||
"//tensorflow/lite/tools/evaluation/stages:image_classification_stage",
|
||||
"//tensorflow/lite/tools/evaluation/tasks:task_executor",
|
||||
"@com_google_absl//absl/types:optional",
|
||||
],
|
||||
)
|
||||
|
||||
cc_binary(
|
||||
name = "run_eval",
|
||||
copts = tflite_copts(),
|
||||
linkopts = task_linkopts(),
|
||||
deps = [
|
||||
":run_eval_lib",
|
||||
"//tensorflow/lite/tools/evaluation/tasks:task_executor_main",
|
||||
],
|
||||
)
|
||||
|
@ -17,12 +17,14 @@ limitations under the License.
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/types/optional.h"
|
||||
#include "tensorflow/lite/c/common.h"
|
||||
#include "tensorflow/lite/tools/command_line_flags.h"
|
||||
#include "tensorflow/lite/tools/evaluation/evaluation_delegate_provider.h"
|
||||
#include "tensorflow/lite/tools/evaluation/proto/evaluation_config.pb.h"
|
||||
#include "tensorflow/lite/tools/evaluation/proto/evaluation_stages.pb.h"
|
||||
#include "tensorflow/lite/tools/evaluation/stages/image_classification_stage.h"
|
||||
#include "tensorflow/lite/tools/evaluation/tasks/task_executor.h"
|
||||
#include "tensorflow/lite/tools/evaluation/utils.h"
|
||||
#include "tensorflow/lite/tools/logging.h"
|
||||
|
||||
@ -46,49 +48,143 @@ std::vector<T> GetFirstN(const std::vector<T>& v, int n) {
|
||||
return result;
|
||||
}
|
||||
|
||||
bool EvaluateModel(const std::string& model_file_path,
|
||||
const std::vector<ImageLabel>& image_labels,
|
||||
const std::vector<std::string>& model_labels,
|
||||
std::string delegate, std::string output_file_path,
|
||||
int num_interpreter_threads,
|
||||
const DelegateProviders& delegate_providers) {
|
||||
class ImagenetClassification : public TaskExecutor {
|
||||
public:
|
||||
ImagenetClassification(int* argc, char* argv[]);
|
||||
~ImagenetClassification() override {}
|
||||
|
||||
// If the run is successful, the latest metrics will be returned.
|
||||
absl::optional<EvaluationStageMetrics> Run() final;
|
||||
|
||||
private:
|
||||
void OutputResult(const EvaluationStageMetrics& latest_metrics) const;
|
||||
std::string model_file_path_;
|
||||
std::string ground_truth_images_path_;
|
||||
std::string ground_truth_labels_path_;
|
||||
std::string model_output_labels_path_;
|
||||
std::string blacklist_file_path_;
|
||||
std::string output_file_path_;
|
||||
std::string delegate_;
|
||||
int num_images_;
|
||||
int num_interpreter_threads_;
|
||||
DelegateProviders delegate_providers_;
|
||||
};
|
||||
|
||||
ImagenetClassification::ImagenetClassification(int* argc, char* argv[])
|
||||
: num_images_(0), num_interpreter_threads_(1) {
|
||||
std::vector<tflite::Flag> flag_list = {
|
||||
tflite::Flag::CreateFlag(kModelFileFlag, &model_file_path_,
|
||||
"Path to test tflite model file."),
|
||||
tflite::Flag::CreateFlag(
|
||||
kModelOutputLabelsFlag, &model_output_labels_path_,
|
||||
"Path to labels that correspond to output of model."
|
||||
" E.g. in case of mobilenet, this is the path to label "
|
||||
"file where each label is in the same order as the output"
|
||||
" of the model."),
|
||||
tflite::Flag::CreateFlag(
|
||||
kGroundTruthImagesPathFlag, &ground_truth_images_path_,
|
||||
"Path to ground truth images. These will be evaluated in "
|
||||
"alphabetical order of filename"),
|
||||
tflite::Flag::CreateFlag(
|
||||
kGroundTruthLabelsFlag, &ground_truth_labels_path_,
|
||||
"Path to ground truth labels, corresponding to alphabetical ordering "
|
||||
"of ground truth images."),
|
||||
tflite::Flag::CreateFlag(
|
||||
kBlacklistFilePathFlag, &blacklist_file_path_,
|
||||
"Path to blacklist file (optional) where each line is a single "
|
||||
"integer that is "
|
||||
"equal to index number of blacklisted image."),
|
||||
tflite::Flag::CreateFlag(kOutputFilePathFlag, &output_file_path_,
|
||||
"File to output metrics proto to."),
|
||||
tflite::Flag::CreateFlag(kNumImagesFlag, &num_images_,
|
||||
"Number of examples to evaluate, pass 0 for all "
|
||||
"examples. Default: 0"),
|
||||
tflite::Flag::CreateFlag(
|
||||
kInterpreterThreadsFlag, &num_interpreter_threads_,
|
||||
"Number of interpreter threads to use for inference."),
|
||||
tflite::Flag::CreateFlag(
|
||||
kDelegateFlag, &delegate_,
|
||||
"Delegate to use for inference, if available. "
|
||||
"Must be one of {'nnapi', 'gpu', 'hexagon', 'xnnpack'}"),
|
||||
};
|
||||
tflite::Flags::Parse(argc, const_cast<const char**>(argv), flag_list);
|
||||
delegate_providers_.InitFromCmdlineArgs(argc, const_cast<const char**>(argv));
|
||||
}
|
||||
|
||||
absl::optional<EvaluationStageMetrics> ImagenetClassification::Run() {
|
||||
// Process images in filename-sorted order.
|
||||
std::vector<std::string> image_files, ground_truth_image_labels;
|
||||
if (GetSortedFileNames(StripTrailingSlashes(ground_truth_images_path_),
|
||||
&image_files) != kTfLiteOk) {
|
||||
return absl::nullopt;
|
||||
}
|
||||
if (!ReadFileLines(ground_truth_labels_path_, &ground_truth_image_labels)) {
|
||||
TFLITE_LOG(ERROR) << "Could not read ground truth labels file";
|
||||
return absl::nullopt;
|
||||
}
|
||||
if (image_files.size() != ground_truth_image_labels.size()) {
|
||||
TFLITE_LOG(ERROR) << "Number of images and ground truth labels is not same";
|
||||
return absl::nullopt;
|
||||
}
|
||||
std::vector<ImageLabel> image_labels;
|
||||
image_labels.reserve(image_files.size());
|
||||
for (int i = 0; i < image_files.size(); i++) {
|
||||
image_labels.push_back({image_files[i], ground_truth_image_labels[i]});
|
||||
}
|
||||
|
||||
// Filter out blacklisted/unwanted images.
|
||||
if (FilterBlackListedImages(blacklist_file_path_, &image_labels) !=
|
||||
kTfLiteOk) {
|
||||
return absl::nullopt;
|
||||
}
|
||||
if (num_images_ > 0) {
|
||||
image_labels = GetFirstN(image_labels, num_images_);
|
||||
}
|
||||
|
||||
std::vector<std::string> model_labels;
|
||||
if (!ReadFileLines(model_output_labels_path_, &model_labels)) {
|
||||
TFLITE_LOG(ERROR) << "Could not read model output labels file";
|
||||
return absl::nullopt;
|
||||
}
|
||||
|
||||
EvaluationStageConfig eval_config;
|
||||
eval_config.set_name("image_classification");
|
||||
auto* classification_params = eval_config.mutable_specification()
|
||||
->mutable_image_classification_params();
|
||||
auto* inference_params = classification_params->mutable_inference_params();
|
||||
inference_params->set_model_file_path(model_file_path);
|
||||
inference_params->set_num_threads(num_interpreter_threads);
|
||||
inference_params->set_delegate(ParseStringToDelegateType(delegate));
|
||||
if (!delegate.empty() &&
|
||||
inference_params->delegate() == TfliteInferenceParams::NONE) {
|
||||
TFLITE_LOG(WARN) << "Unsupported TFLite delegate: " << delegate;
|
||||
return false;
|
||||
}
|
||||
inference_params->set_model_file_path(model_file_path_);
|
||||
inference_params->set_num_threads(num_interpreter_threads_);
|
||||
inference_params->set_delegate(ParseStringToDelegateType(delegate_));
|
||||
classification_params->mutable_topk_accuracy_eval_params()->set_k(10);
|
||||
|
||||
ImageClassificationStage eval(eval_config);
|
||||
|
||||
eval.SetAllLabels(model_labels);
|
||||
if (eval.Init(&delegate_providers) != kTfLiteOk) return false;
|
||||
if (eval.Init(&delegate_providers_) != kTfLiteOk) return absl::nullopt;
|
||||
|
||||
const int step = image_labels.size() / 100;
|
||||
for (int i = 0; i < image_labels.size(); ++i) {
|
||||
if (step > 1 && i % step == 0) {
|
||||
TFLITE_LOG(INFO) << "Evaluated: " << i / step << "%";
|
||||
}
|
||||
|
||||
eval.SetInputs(image_labels[i].image, image_labels[i].label);
|
||||
if (eval.Run() != kTfLiteOk) return false;
|
||||
if (eval.Run() != kTfLiteOk) return absl::nullopt;
|
||||
}
|
||||
|
||||
const auto latest_metrics = eval.LatestMetrics();
|
||||
if (!output_file_path.empty()) {
|
||||
OutputResult(latest_metrics);
|
||||
return absl::make_optional(latest_metrics);
|
||||
}
|
||||
|
||||
void ImagenetClassification::OutputResult(
|
||||
const EvaluationStageMetrics& latest_metrics) const {
|
||||
if (!output_file_path_.empty()) {
|
||||
std::ofstream metrics_ofile;
|
||||
metrics_ofile.open(output_file_path, std::ios::out);
|
||||
metrics_ofile.open(output_file_path_, std::ios::out);
|
||||
metrics_ofile << latest_metrics.SerializeAsString();
|
||||
metrics_ofile.close();
|
||||
}
|
||||
|
||||
TFLITE_LOG(INFO) << "Num evaluation runs: " << latest_metrics.num_runs();
|
||||
const auto& metrics =
|
||||
latest_metrics.process_metrics().image_classification_metrics();
|
||||
@ -105,103 +201,11 @@ bool EvaluateModel(const std::string& model_file_path,
|
||||
TFLITE_LOG(INFO) << "Top-" << i + 1
|
||||
<< " Accuracy: " << accuracy_metrics.topk_accuracies(i);
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
int Main(int argc, char* argv[]) {
|
||||
// Command Line Flags.
|
||||
std::string model_file_path;
|
||||
std::string ground_truth_images_path;
|
||||
std::string ground_truth_labels_path;
|
||||
std::string model_output_labels_path;
|
||||
std::string blacklist_file_path;
|
||||
std::string output_file_path;
|
||||
std::string delegate;
|
||||
int num_images = 0;
|
||||
int num_interpreter_threads = 1;
|
||||
std::vector<tflite::Flag> flag_list = {
|
||||
tflite::Flag::CreateFlag(kModelFileFlag, &model_file_path,
|
||||
"Path to test tflite model file."),
|
||||
tflite::Flag::CreateFlag(
|
||||
kModelOutputLabelsFlag, &model_output_labels_path,
|
||||
"Path to labels that correspond to output of model."
|
||||
" E.g. in case of mobilenet, this is the path to label "
|
||||
"file where each label is in the same order as the output"
|
||||
" of the model."),
|
||||
tflite::Flag::CreateFlag(
|
||||
kGroundTruthImagesPathFlag, &ground_truth_images_path,
|
||||
"Path to ground truth images. These will be evaluated in "
|
||||
"alphabetical order of filename"),
|
||||
tflite::Flag::CreateFlag(
|
||||
kGroundTruthLabelsFlag, &ground_truth_labels_path,
|
||||
"Path to ground truth labels, corresponding to alphabetical ordering "
|
||||
"of ground truth images."),
|
||||
tflite::Flag::CreateFlag(
|
||||
kBlacklistFilePathFlag, &blacklist_file_path,
|
||||
"Path to blacklist file (optional) where each line is a single "
|
||||
"integer that is "
|
||||
"equal to index number of blacklisted image."),
|
||||
tflite::Flag::CreateFlag(kOutputFilePathFlag, &output_file_path,
|
||||
"File to output metrics proto to."),
|
||||
tflite::Flag::CreateFlag(kNumImagesFlag, &num_images,
|
||||
"Number of examples to evaluate, pass 0 for all "
|
||||
"examples. Default: 0"),
|
||||
tflite::Flag::CreateFlag(
|
||||
kInterpreterThreadsFlag, &num_interpreter_threads,
|
||||
"Number of interpreter threads to use for inference."),
|
||||
tflite::Flag::CreateFlag(kDelegateFlag, &delegate,
|
||||
"Delegate to use for inference, if available. "
|
||||
"Must be one of {'nnapi', 'gpu'}"),
|
||||
};
|
||||
tflite::Flags::Parse(&argc, const_cast<const char**>(argv), flag_list);
|
||||
DelegateProviders delegate_providers;
|
||||
delegate_providers.InitFromCmdlineArgs(&argc, const_cast<const char**>(argv));
|
||||
|
||||
// Process images in filename-sorted order.
|
||||
std::vector<std::string> image_files, ground_truth_image_labels;
|
||||
TF_LITE_ENSURE_STATUS(GetSortedFileNames(
|
||||
StripTrailingSlashes(ground_truth_images_path), &image_files));
|
||||
if (!ReadFileLines(ground_truth_labels_path, &ground_truth_image_labels)) {
|
||||
TFLITE_LOG(ERROR) << "Could not read ground truth labels file";
|
||||
return EXIT_FAILURE;
|
||||
}
|
||||
if (image_files.size() != ground_truth_image_labels.size()) {
|
||||
TFLITE_LOG(ERROR) << "Number of images and ground truth labels is not same";
|
||||
return EXIT_FAILURE;
|
||||
}
|
||||
std::vector<ImageLabel> image_labels;
|
||||
image_labels.reserve(image_files.size());
|
||||
for (int i = 0; i < image_files.size(); i++) {
|
||||
image_labels.push_back({image_files[i], ground_truth_image_labels[i]});
|
||||
}
|
||||
|
||||
// Filter out blacklisted/unwanted images.
|
||||
TF_LITE_ENSURE_STATUS(
|
||||
FilterBlackListedImages(blacklist_file_path, &image_labels));
|
||||
if (num_images > 0) {
|
||||
image_labels = GetFirstN(image_labels, num_images);
|
||||
}
|
||||
|
||||
std::vector<std::string> model_labels;
|
||||
if (!ReadFileLines(model_output_labels_path, &model_labels)) {
|
||||
TFLITE_LOG(ERROR) << "Could not read model output labels file";
|
||||
return EXIT_FAILURE;
|
||||
}
|
||||
|
||||
if (!EvaluateModel(model_file_path, image_labels, model_labels, delegate,
|
||||
output_file_path, num_interpreter_threads,
|
||||
delegate_providers)) {
|
||||
TFLITE_LOG(ERROR) << "Could not evaluate model";
|
||||
return EXIT_FAILURE;
|
||||
}
|
||||
|
||||
return EXIT_SUCCESS;
|
||||
std::unique_ptr<TaskExecutor> CreateTaskExecutor(int* argc, char* argv[]) {
|
||||
return std::unique_ptr<TaskExecutor>(new ImagenetClassification(argc, argv));
|
||||
}
|
||||
|
||||
} // namespace evaluation
|
||||
} // namespace tflite
|
||||
|
||||
int main(int argc, char* argv[]) {
|
||||
return tflite::evaluation::Main(argc, argv);
|
||||
}
|
||||
|
@ -1,4 +1,5 @@
|
||||
load("//tensorflow/lite:build_def.bzl", "tflite_copts", "tflite_linkopts")
|
||||
load("//tensorflow/lite:build_def.bzl", "tflite_copts")
|
||||
load("//tensorflow/lite/tools/evaluation/tasks:build_def.bzl", "task_linkopts")
|
||||
|
||||
package(
|
||||
default_visibility = [
|
||||
@ -7,18 +8,11 @@ package(
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
)
|
||||
|
||||
cc_binary(
|
||||
name = "run_eval",
|
||||
cc_library(
|
||||
name = "run_eval_lib",
|
||||
srcs = ["run_eval.cc"],
|
||||
copts = tflite_copts(),
|
||||
linkopts = tflite_linkopts() + select({
|
||||
"//tensorflow:android": [
|
||||
"-pie", # Android 5.0 and later supports only PIE
|
||||
"-lm", # some builtin ops, e.g., tanh, need -lm
|
||||
"-Wl,--rpath=/data/local/tmp/", # Hexagon delegate libraries should be in /data/local/tmp
|
||||
],
|
||||
"//conditions:default": [],
|
||||
}),
|
||||
linkopts = task_linkopts(),
|
||||
deps = [
|
||||
"//tensorflow/lite/c:common",
|
||||
"//tensorflow/lite/tools:command_line_flags",
|
||||
@ -28,5 +22,17 @@ cc_binary(
|
||||
"//tensorflow/lite/tools/evaluation/proto:evaluation_config_cc_proto",
|
||||
"//tensorflow/lite/tools/evaluation/proto:evaluation_stages_cc_proto",
|
||||
"//tensorflow/lite/tools/evaluation/stages:inference_profiler_stage",
|
||||
"//tensorflow/lite/tools/evaluation/tasks:task_executor",
|
||||
"@com_google_absl//absl/types:optional",
|
||||
],
|
||||
)
|
||||
|
||||
cc_binary(
|
||||
name = "run_eval",
|
||||
copts = tflite_copts(),
|
||||
linkopts = task_linkopts(),
|
||||
deps = [
|
||||
":run_eval_lib",
|
||||
"//tensorflow/lite/tools/evaluation/tasks:task_executor_main",
|
||||
],
|
||||
)
|
||||
|
@ -16,12 +16,14 @@ limitations under the License.
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/types/optional.h"
|
||||
#include "tensorflow/lite/c/common.h"
|
||||
#include "tensorflow/lite/tools/command_line_flags.h"
|
||||
#include "tensorflow/lite/tools/evaluation/evaluation_delegate_provider.h"
|
||||
#include "tensorflow/lite/tools/evaluation/proto/evaluation_config.pb.h"
|
||||
#include "tensorflow/lite/tools/evaluation/proto/evaluation_stages.pb.h"
|
||||
#include "tensorflow/lite/tools/evaluation/stages/inference_profiler_stage.h"
|
||||
#include "tensorflow/lite/tools/evaluation/tasks/task_executor.h"
|
||||
#include "tensorflow/lite/tools/logging.h"
|
||||
|
||||
namespace tflite {
|
||||
@ -33,43 +35,88 @@ constexpr char kNumRunsFlag[] = "num_runs";
|
||||
constexpr char kInterpreterThreadsFlag[] = "num_interpreter_threads";
|
||||
constexpr char kDelegateFlag[] = "delegate";
|
||||
|
||||
bool EvaluateModel(const std::string& model_file_path,
|
||||
const std::string& delegate, int num_runs,
|
||||
const std::string& output_file_path,
|
||||
int num_interpreter_threads,
|
||||
const DelegateProviders& delegate_providers) {
|
||||
class InferenceDiff : public TaskExecutor {
|
||||
public:
|
||||
InferenceDiff(int* argc, char* argv[]);
|
||||
~InferenceDiff() override {}
|
||||
|
||||
// If the run is successful, the latest metrics will be returned.
|
||||
absl::optional<EvaluationStageMetrics> Run() final;
|
||||
|
||||
private:
|
||||
void OutputResult(const EvaluationStageMetrics& latest_metrics) const;
|
||||
std::string model_file_path_;
|
||||
std::string output_file_path_;
|
||||
std::string delegate_;
|
||||
int num_runs_;
|
||||
int num_interpreter_threads_;
|
||||
DelegateProviders delegate_providers_;
|
||||
};
|
||||
|
||||
InferenceDiff::InferenceDiff(int* argc, char* argv[])
|
||||
: num_runs_(50), num_interpreter_threads_(1) {
|
||||
// Command Line Flags.
|
||||
std::vector<tflite::Flag> flag_list = {
|
||||
tflite::Flag::CreateFlag(kModelFileFlag, &model_file_path_,
|
||||
"Path to test tflite model file."),
|
||||
tflite::Flag::CreateFlag(kOutputFilePathFlag, &output_file_path_,
|
||||
"File to output metrics proto to."),
|
||||
tflite::Flag::CreateFlag(kNumRunsFlag, &num_runs_,
|
||||
"Number of runs of test & reference inference "
|
||||
"each. Default value: 50"),
|
||||
tflite::Flag::CreateFlag(
|
||||
kInterpreterThreadsFlag, &num_interpreter_threads_,
|
||||
"Number of interpreter threads to use for test inference."),
|
||||
tflite::Flag::CreateFlag(
|
||||
kDelegateFlag, &delegate_,
|
||||
"Delegate to use for test inference, if available. "
|
||||
"Must be one of {'nnapi', 'gpu', 'hexagon', 'xnnpack'}"),
|
||||
};
|
||||
tflite::Flags::Parse(argc, const_cast<const char**>(argv), flag_list);
|
||||
delegate_providers_.InitFromCmdlineArgs(argc, const_cast<const char**>(argv));
|
||||
}
|
||||
|
||||
absl::optional<EvaluationStageMetrics> InferenceDiff::Run() {
|
||||
// Initialize evaluation stage.
|
||||
EvaluationStageConfig eval_config;
|
||||
eval_config.set_name("inference_profiling");
|
||||
auto* inference_params =
|
||||
eval_config.mutable_specification()->mutable_tflite_inference_params();
|
||||
inference_params->set_model_file_path(model_file_path);
|
||||
inference_params->set_num_threads(num_interpreter_threads);
|
||||
inference_params->set_model_file_path(model_file_path_);
|
||||
inference_params->set_num_threads(num_interpreter_threads_);
|
||||
// This ensures that latency measurement isn't hampered by the time spent in
|
||||
// generating random data.
|
||||
inference_params->set_invocations_per_run(3);
|
||||
inference_params->set_delegate(ParseStringToDelegateType(delegate));
|
||||
if (!delegate.empty() &&
|
||||
inference_params->set_delegate(ParseStringToDelegateType(delegate_));
|
||||
if (!delegate_.empty() &&
|
||||
inference_params->delegate() == TfliteInferenceParams::NONE) {
|
||||
TFLITE_LOG(WARN) << "Unsupported TFLite delegate: " << delegate;
|
||||
return false;
|
||||
TFLITE_LOG(WARN) << "Unsupported TFLite delegate: " << delegate_;
|
||||
return absl::nullopt;
|
||||
}
|
||||
|
||||
InferenceProfilerStage eval(eval_config);
|
||||
if (eval.Init(&delegate_providers) != kTfLiteOk) return false;
|
||||
if (eval.Init(&delegate_providers_) != kTfLiteOk) return absl::nullopt;
|
||||
|
||||
// Run inference & check diff for specified number of runs.
|
||||
for (int i = 0; i < num_runs; ++i) {
|
||||
if (eval.Run() != kTfLiteOk) return false;
|
||||
for (int i = 0; i < num_runs_; ++i) {
|
||||
if (eval.Run() != kTfLiteOk) return absl::nullopt;
|
||||
}
|
||||
|
||||
// Output latency & diff metrics.
|
||||
const auto latest_metrics = eval.LatestMetrics();
|
||||
if (!output_file_path.empty()) {
|
||||
OutputResult(latest_metrics);
|
||||
return absl::make_optional(latest_metrics);
|
||||
}
|
||||
|
||||
void InferenceDiff::OutputResult(
|
||||
const EvaluationStageMetrics& latest_metrics) const {
|
||||
// Output latency & diff metrics.
|
||||
if (!output_file_path_.empty()) {
|
||||
std::ofstream metrics_ofile;
|
||||
metrics_ofile.open(output_file_path, std::ios::out);
|
||||
metrics_ofile.open(output_file_path_, std::ios::out);
|
||||
metrics_ofile << latest_metrics.SerializeAsString();
|
||||
metrics_ofile.close();
|
||||
}
|
||||
|
||||
TFLITE_LOG(INFO) << "Num evaluation runs: " << latest_metrics.num_runs();
|
||||
const auto& metrics =
|
||||
latest_metrics.process_metrics().inference_profiler_metrics();
|
||||
@ -88,48 +135,11 @@ bool EvaluateModel(const std::string& model_file_path,
|
||||
<< "]: avg_error=" << error.avg_value()
|
||||
<< ", std_dev=" << error.std_deviation();
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
int Main(int argc, char* argv[]) {
|
||||
// Command Line Flags.
|
||||
std::string model_file_path;
|
||||
std::string output_file_path;
|
||||
std::string delegate;
|
||||
int num_runs = 50;
|
||||
int num_interpreter_threads = 1;
|
||||
std::vector<tflite::Flag> flag_list = {
|
||||
tflite::Flag::CreateFlag(kModelFileFlag, &model_file_path,
|
||||
"Path to test tflite model file."),
|
||||
tflite::Flag::CreateFlag(kOutputFilePathFlag, &output_file_path,
|
||||
"File to output metrics proto to."),
|
||||
tflite::Flag::CreateFlag(kNumRunsFlag, &num_runs,
|
||||
"Number of runs of test & reference inference "
|
||||
"each. Default value: 50"),
|
||||
tflite::Flag::CreateFlag(
|
||||
kInterpreterThreadsFlag, &num_interpreter_threads,
|
||||
"Number of interpreter threads to use for test inference."),
|
||||
tflite::Flag::CreateFlag(
|
||||
kDelegateFlag, &delegate,
|
||||
"Delegate to use for test inference, if available. "
|
||||
"Must be one of {'nnapi', 'gpu', 'hexagon'}"),
|
||||
};
|
||||
tflite::Flags::Parse(&argc, const_cast<const char**>(argv), flag_list);
|
||||
|
||||
DelegateProviders delegate_providers;
|
||||
delegate_providers.InitFromCmdlineArgs(&argc, const_cast<const char**>(argv));
|
||||
if (!EvaluateModel(model_file_path, delegate, num_runs, output_file_path,
|
||||
num_interpreter_threads, delegate_providers)) {
|
||||
TFLITE_LOG(ERROR) << "Could not evaluate model!";
|
||||
return EXIT_FAILURE;
|
||||
}
|
||||
|
||||
return EXIT_SUCCESS;
|
||||
std::unique_ptr<TaskExecutor> CreateTaskExecutor(int* argc, char* argv[]) {
|
||||
return std::unique_ptr<TaskExecutor>(new InferenceDiff(argc, argv));
|
||||
}
|
||||
|
||||
} // namespace evaluation
|
||||
} // namespace tflite
|
||||
|
||||
int main(int argc, char* argv[]) {
|
||||
return tflite::evaluation::Main(argc, argv);
|
||||
}
|
||||
|
38
tensorflow/lite/tools/evaluation/tasks/task_executor.h
Normal file
38
tensorflow/lite/tools/evaluation/tasks/task_executor.h
Normal file
@ -0,0 +1,38 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_LITE_TOOLS_EVALUATION_TASKS_TASK_EXECUTOR_H_
|
||||
#define TENSORFLOW_LITE_TOOLS_EVALUATION_TASKS_TASK_EXECUTOR_H_
|
||||
|
||||
#include "absl/types/optional.h"
|
||||
#include "tensorflow/lite/tools/evaluation/proto/evaluation_config.pb.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace evaluation {
|
||||
// A common task execution API to avoid boilerpolate code in defining the main
|
||||
// function.
|
||||
class TaskExecutor {
|
||||
public:
|
||||
virtual ~TaskExecutor() {}
|
||||
// If the run is successful, the latest metrics will be returned.
|
||||
virtual absl::optional<EvaluationStageMetrics> Run() = 0;
|
||||
};
|
||||
|
||||
// Just a declaration. In order to avoid the boilerpolate main-function code,
|
||||
// every evaluation task should define this function.
|
||||
std::unique_ptr<TaskExecutor> CreateTaskExecutor(int* argc, char* argv[]);
|
||||
} // namespace evaluation
|
||||
} // namespace tflite
|
||||
|
||||
#endif // TENSORFLOW_LITE_TOOLS_EVALUATION_TASKS_TASK_EXECUTOR_H_
|
32
tensorflow/lite/tools/evaluation/tasks/task_executor_main.cc
Normal file
32
tensorflow/lite/tools/evaluation/tasks/task_executor_main.cc
Normal file
@ -0,0 +1,32 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "absl/types/optional.h"
|
||||
#include "tensorflow/lite/tools/evaluation/tasks/task_executor.h"
|
||||
#include "tensorflow/lite/tools/logging.h"
|
||||
|
||||
// This could serve as the main function for all eval tools.
|
||||
int main(int argc, char* argv[]) {
|
||||
auto task_executor = tflite::evaluation::CreateTaskExecutor(&argc, argv);
|
||||
if (task_executor == nullptr) {
|
||||
TFLITE_LOG(ERROR) << "Could not create the task evaluation!";
|
||||
return EXIT_FAILURE;
|
||||
}
|
||||
const auto metrics = task_executor->Run();
|
||||
if (!metrics.has_value()) {
|
||||
TFLITE_LOG(ERROR) << "Could not run the task evaluation!";
|
||||
return EXIT_FAILURE;
|
||||
}
|
||||
return EXIT_SUCCESS;
|
||||
}
|
Loading…
Reference in New Issue
Block a user