Refactor the inference_diff and imagenet-accuracy evaluation task to become a utility library, and create a common main stub function for these evaluation tasks.
PiperOrigin-RevId: 307327740 Change-Id: Id1c081fa3ad1796c6ff60441f9189a985148fb09
This commit is contained in:
parent
bbaaccaab2
commit
a8c7cc2079
32
tensorflow/lite/tools/evaluation/tasks/BUILD
Normal file
32
tensorflow/lite/tools/evaluation/tasks/BUILD
Normal file
@ -0,0 +1,32 @@
|
|||||||
|
load("//tensorflow/lite:build_def.bzl", "tflite_copts")
|
||||||
|
load("//tensorflow/lite/tools/evaluation/tasks:build_def.bzl", "task_linkopts")
|
||||||
|
|
||||||
|
package(
|
||||||
|
default_visibility = [
|
||||||
|
"//visibility:public",
|
||||||
|
],
|
||||||
|
licenses = ["notice"], # Apache 2.0
|
||||||
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "task_executor",
|
||||||
|
hdrs = ["task_executor.h"],
|
||||||
|
copts = tflite_copts(),
|
||||||
|
linkopts = task_linkopts(),
|
||||||
|
deps = [
|
||||||
|
"//tensorflow/lite/tools/evaluation/proto:evaluation_config_cc_proto",
|
||||||
|
"@com_google_absl//absl/types:optional",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "task_executor_main",
|
||||||
|
srcs = ["task_executor_main.cc"],
|
||||||
|
copts = tflite_copts(),
|
||||||
|
linkopts = task_linkopts(),
|
||||||
|
deps = [
|
||||||
|
":task_executor",
|
||||||
|
"//tensorflow/lite/tools:logging",
|
||||||
|
"@com_google_absl//absl/types:optional",
|
||||||
|
],
|
||||||
|
)
|
14
tensorflow/lite/tools/evaluation/tasks/build_def.bzl
Normal file
14
tensorflow/lite/tools/evaluation/tasks/build_def.bzl
Normal file
@ -0,0 +1,14 @@
|
|||||||
|
"""Common BUILD-related definitions across different tasks"""
|
||||||
|
|
||||||
|
load("//tensorflow/lite:build_def.bzl", "tflite_linkopts")
|
||||||
|
|
||||||
|
def task_linkopts():
|
||||||
|
return tflite_linkopts() + select({
|
||||||
|
"//tensorflow:android": [
|
||||||
|
"-pie", # Android 5.0 and later supports only PIE
|
||||||
|
"-lm", # some builtin ops, e.g., tanh, need -lm
|
||||||
|
# Hexagon delegate libraries should be in /data/local/tmp
|
||||||
|
"-Wl,--rpath=/data/local/tmp/",
|
||||||
|
],
|
||||||
|
"//conditions:default": [],
|
||||||
|
})
|
@ -1,4 +1,5 @@
|
|||||||
load("//tensorflow/lite:build_def.bzl", "tflite_copts", "tflite_linkopts")
|
load("//tensorflow/lite:build_def.bzl", "tflite_copts")
|
||||||
|
load("//tensorflow/lite/tools/evaluation/tasks:build_def.bzl", "task_linkopts")
|
||||||
|
|
||||||
package(
|
package(
|
||||||
default_visibility = [
|
default_visibility = [
|
||||||
@ -7,20 +8,11 @@ package(
|
|||||||
licenses = ["notice"], # Apache 2.0
|
licenses = ["notice"], # Apache 2.0
|
||||||
)
|
)
|
||||||
|
|
||||||
common_linkopts = tflite_linkopts() + select({
|
cc_library(
|
||||||
"//tensorflow:android": [
|
name = "run_eval_lib",
|
||||||
"-pie", # Android 5.0 and later supports only PIE
|
|
||||||
"-lm", # some builtin ops, e.g., tanh, need -lm
|
|
||||||
"-Wl,--rpath=/data/local/tmp/", # Hexagon delegate libraries should be in /data/local/tmp
|
|
||||||
],
|
|
||||||
"//conditions:default": [],
|
|
||||||
})
|
|
||||||
|
|
||||||
cc_binary(
|
|
||||||
name = "run_eval",
|
|
||||||
srcs = ["run_eval.cc"],
|
srcs = ["run_eval.cc"],
|
||||||
copts = tflite_copts(),
|
copts = tflite_copts(),
|
||||||
linkopts = common_linkopts,
|
linkopts = task_linkopts(),
|
||||||
deps = [
|
deps = [
|
||||||
"//tensorflow/lite/c:common",
|
"//tensorflow/lite/c:common",
|
||||||
"//tensorflow/lite/tools:command_line_flags",
|
"//tensorflow/lite/tools:command_line_flags",
|
||||||
@ -31,5 +23,17 @@ cc_binary(
|
|||||||
"//tensorflow/lite/tools/evaluation/proto:evaluation_config_cc_proto",
|
"//tensorflow/lite/tools/evaluation/proto:evaluation_config_cc_proto",
|
||||||
"//tensorflow/lite/tools/evaluation/proto:evaluation_stages_cc_proto",
|
"//tensorflow/lite/tools/evaluation/proto:evaluation_stages_cc_proto",
|
||||||
"//tensorflow/lite/tools/evaluation/stages:image_classification_stage",
|
"//tensorflow/lite/tools/evaluation/stages:image_classification_stage",
|
||||||
|
"//tensorflow/lite/tools/evaluation/tasks:task_executor",
|
||||||
|
"@com_google_absl//absl/types:optional",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
cc_binary(
|
||||||
|
name = "run_eval",
|
||||||
|
copts = tflite_copts(),
|
||||||
|
linkopts = task_linkopts(),
|
||||||
|
deps = [
|
||||||
|
":run_eval_lib",
|
||||||
|
"//tensorflow/lite/tools/evaluation/tasks:task_executor_main",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
@ -17,12 +17,14 @@ limitations under the License.
|
|||||||
#include <string>
|
#include <string>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
#include "absl/types/optional.h"
|
||||||
#include "tensorflow/lite/c/common.h"
|
#include "tensorflow/lite/c/common.h"
|
||||||
#include "tensorflow/lite/tools/command_line_flags.h"
|
#include "tensorflow/lite/tools/command_line_flags.h"
|
||||||
#include "tensorflow/lite/tools/evaluation/evaluation_delegate_provider.h"
|
#include "tensorflow/lite/tools/evaluation/evaluation_delegate_provider.h"
|
||||||
#include "tensorflow/lite/tools/evaluation/proto/evaluation_config.pb.h"
|
#include "tensorflow/lite/tools/evaluation/proto/evaluation_config.pb.h"
|
||||||
#include "tensorflow/lite/tools/evaluation/proto/evaluation_stages.pb.h"
|
#include "tensorflow/lite/tools/evaluation/proto/evaluation_stages.pb.h"
|
||||||
#include "tensorflow/lite/tools/evaluation/stages/image_classification_stage.h"
|
#include "tensorflow/lite/tools/evaluation/stages/image_classification_stage.h"
|
||||||
|
#include "tensorflow/lite/tools/evaluation/tasks/task_executor.h"
|
||||||
#include "tensorflow/lite/tools/evaluation/utils.h"
|
#include "tensorflow/lite/tools/evaluation/utils.h"
|
||||||
#include "tensorflow/lite/tools/logging.h"
|
#include "tensorflow/lite/tools/logging.h"
|
||||||
|
|
||||||
@ -46,49 +48,143 @@ std::vector<T> GetFirstN(const std::vector<T>& v, int n) {
|
|||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool EvaluateModel(const std::string& model_file_path,
|
class ImagenetClassification : public TaskExecutor {
|
||||||
const std::vector<ImageLabel>& image_labels,
|
public:
|
||||||
const std::vector<std::string>& model_labels,
|
ImagenetClassification(int* argc, char* argv[]);
|
||||||
std::string delegate, std::string output_file_path,
|
~ImagenetClassification() override {}
|
||||||
int num_interpreter_threads,
|
|
||||||
const DelegateProviders& delegate_providers) {
|
// If the run is successful, the latest metrics will be returned.
|
||||||
|
absl::optional<EvaluationStageMetrics> Run() final;
|
||||||
|
|
||||||
|
private:
|
||||||
|
void OutputResult(const EvaluationStageMetrics& latest_metrics) const;
|
||||||
|
std::string model_file_path_;
|
||||||
|
std::string ground_truth_images_path_;
|
||||||
|
std::string ground_truth_labels_path_;
|
||||||
|
std::string model_output_labels_path_;
|
||||||
|
std::string blacklist_file_path_;
|
||||||
|
std::string output_file_path_;
|
||||||
|
std::string delegate_;
|
||||||
|
int num_images_;
|
||||||
|
int num_interpreter_threads_;
|
||||||
|
DelegateProviders delegate_providers_;
|
||||||
|
};
|
||||||
|
|
||||||
|
ImagenetClassification::ImagenetClassification(int* argc, char* argv[])
|
||||||
|
: num_images_(0), num_interpreter_threads_(1) {
|
||||||
|
std::vector<tflite::Flag> flag_list = {
|
||||||
|
tflite::Flag::CreateFlag(kModelFileFlag, &model_file_path_,
|
||||||
|
"Path to test tflite model file."),
|
||||||
|
tflite::Flag::CreateFlag(
|
||||||
|
kModelOutputLabelsFlag, &model_output_labels_path_,
|
||||||
|
"Path to labels that correspond to output of model."
|
||||||
|
" E.g. in case of mobilenet, this is the path to label "
|
||||||
|
"file where each label is in the same order as the output"
|
||||||
|
" of the model."),
|
||||||
|
tflite::Flag::CreateFlag(
|
||||||
|
kGroundTruthImagesPathFlag, &ground_truth_images_path_,
|
||||||
|
"Path to ground truth images. These will be evaluated in "
|
||||||
|
"alphabetical order of filename"),
|
||||||
|
tflite::Flag::CreateFlag(
|
||||||
|
kGroundTruthLabelsFlag, &ground_truth_labels_path_,
|
||||||
|
"Path to ground truth labels, corresponding to alphabetical ordering "
|
||||||
|
"of ground truth images."),
|
||||||
|
tflite::Flag::CreateFlag(
|
||||||
|
kBlacklistFilePathFlag, &blacklist_file_path_,
|
||||||
|
"Path to blacklist file (optional) where each line is a single "
|
||||||
|
"integer that is "
|
||||||
|
"equal to index number of blacklisted image."),
|
||||||
|
tflite::Flag::CreateFlag(kOutputFilePathFlag, &output_file_path_,
|
||||||
|
"File to output metrics proto to."),
|
||||||
|
tflite::Flag::CreateFlag(kNumImagesFlag, &num_images_,
|
||||||
|
"Number of examples to evaluate, pass 0 for all "
|
||||||
|
"examples. Default: 0"),
|
||||||
|
tflite::Flag::CreateFlag(
|
||||||
|
kInterpreterThreadsFlag, &num_interpreter_threads_,
|
||||||
|
"Number of interpreter threads to use for inference."),
|
||||||
|
tflite::Flag::CreateFlag(
|
||||||
|
kDelegateFlag, &delegate_,
|
||||||
|
"Delegate to use for inference, if available. "
|
||||||
|
"Must be one of {'nnapi', 'gpu', 'hexagon', 'xnnpack'}"),
|
||||||
|
};
|
||||||
|
tflite::Flags::Parse(argc, const_cast<const char**>(argv), flag_list);
|
||||||
|
delegate_providers_.InitFromCmdlineArgs(argc, const_cast<const char**>(argv));
|
||||||
|
}
|
||||||
|
|
||||||
|
absl::optional<EvaluationStageMetrics> ImagenetClassification::Run() {
|
||||||
|
// Process images in filename-sorted order.
|
||||||
|
std::vector<std::string> image_files, ground_truth_image_labels;
|
||||||
|
if (GetSortedFileNames(StripTrailingSlashes(ground_truth_images_path_),
|
||||||
|
&image_files) != kTfLiteOk) {
|
||||||
|
return absl::nullopt;
|
||||||
|
}
|
||||||
|
if (!ReadFileLines(ground_truth_labels_path_, &ground_truth_image_labels)) {
|
||||||
|
TFLITE_LOG(ERROR) << "Could not read ground truth labels file";
|
||||||
|
return absl::nullopt;
|
||||||
|
}
|
||||||
|
if (image_files.size() != ground_truth_image_labels.size()) {
|
||||||
|
TFLITE_LOG(ERROR) << "Number of images and ground truth labels is not same";
|
||||||
|
return absl::nullopt;
|
||||||
|
}
|
||||||
|
std::vector<ImageLabel> image_labels;
|
||||||
|
image_labels.reserve(image_files.size());
|
||||||
|
for (int i = 0; i < image_files.size(); i++) {
|
||||||
|
image_labels.push_back({image_files[i], ground_truth_image_labels[i]});
|
||||||
|
}
|
||||||
|
|
||||||
|
// Filter out blacklisted/unwanted images.
|
||||||
|
if (FilterBlackListedImages(blacklist_file_path_, &image_labels) !=
|
||||||
|
kTfLiteOk) {
|
||||||
|
return absl::nullopt;
|
||||||
|
}
|
||||||
|
if (num_images_ > 0) {
|
||||||
|
image_labels = GetFirstN(image_labels, num_images_);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<std::string> model_labels;
|
||||||
|
if (!ReadFileLines(model_output_labels_path_, &model_labels)) {
|
||||||
|
TFLITE_LOG(ERROR) << "Could not read model output labels file";
|
||||||
|
return absl::nullopt;
|
||||||
|
}
|
||||||
|
|
||||||
EvaluationStageConfig eval_config;
|
EvaluationStageConfig eval_config;
|
||||||
eval_config.set_name("image_classification");
|
eval_config.set_name("image_classification");
|
||||||
auto* classification_params = eval_config.mutable_specification()
|
auto* classification_params = eval_config.mutable_specification()
|
||||||
->mutable_image_classification_params();
|
->mutable_image_classification_params();
|
||||||
auto* inference_params = classification_params->mutable_inference_params();
|
auto* inference_params = classification_params->mutable_inference_params();
|
||||||
inference_params->set_model_file_path(model_file_path);
|
inference_params->set_model_file_path(model_file_path_);
|
||||||
inference_params->set_num_threads(num_interpreter_threads);
|
inference_params->set_num_threads(num_interpreter_threads_);
|
||||||
inference_params->set_delegate(ParseStringToDelegateType(delegate));
|
inference_params->set_delegate(ParseStringToDelegateType(delegate_));
|
||||||
if (!delegate.empty() &&
|
|
||||||
inference_params->delegate() == TfliteInferenceParams::NONE) {
|
|
||||||
TFLITE_LOG(WARN) << "Unsupported TFLite delegate: " << delegate;
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
classification_params->mutable_topk_accuracy_eval_params()->set_k(10);
|
classification_params->mutable_topk_accuracy_eval_params()->set_k(10);
|
||||||
|
|
||||||
ImageClassificationStage eval(eval_config);
|
ImageClassificationStage eval(eval_config);
|
||||||
|
|
||||||
eval.SetAllLabels(model_labels);
|
eval.SetAllLabels(model_labels);
|
||||||
if (eval.Init(&delegate_providers) != kTfLiteOk) return false;
|
if (eval.Init(&delegate_providers_) != kTfLiteOk) return absl::nullopt;
|
||||||
|
|
||||||
const int step = image_labels.size() / 100;
|
const int step = image_labels.size() / 100;
|
||||||
for (int i = 0; i < image_labels.size(); ++i) {
|
for (int i = 0; i < image_labels.size(); ++i) {
|
||||||
if (step > 1 && i % step == 0) {
|
if (step > 1 && i % step == 0) {
|
||||||
TFLITE_LOG(INFO) << "Evaluated: " << i / step << "%";
|
TFLITE_LOG(INFO) << "Evaluated: " << i / step << "%";
|
||||||
}
|
}
|
||||||
|
|
||||||
eval.SetInputs(image_labels[i].image, image_labels[i].label);
|
eval.SetInputs(image_labels[i].image, image_labels[i].label);
|
||||||
if (eval.Run() != kTfLiteOk) return false;
|
if (eval.Run() != kTfLiteOk) return absl::nullopt;
|
||||||
}
|
}
|
||||||
|
|
||||||
const auto latest_metrics = eval.LatestMetrics();
|
const auto latest_metrics = eval.LatestMetrics();
|
||||||
if (!output_file_path.empty()) {
|
OutputResult(latest_metrics);
|
||||||
|
return absl::make_optional(latest_metrics);
|
||||||
|
}
|
||||||
|
|
||||||
|
void ImagenetClassification::OutputResult(
|
||||||
|
const EvaluationStageMetrics& latest_metrics) const {
|
||||||
|
if (!output_file_path_.empty()) {
|
||||||
std::ofstream metrics_ofile;
|
std::ofstream metrics_ofile;
|
||||||
metrics_ofile.open(output_file_path, std::ios::out);
|
metrics_ofile.open(output_file_path_, std::ios::out);
|
||||||
metrics_ofile << latest_metrics.SerializeAsString();
|
metrics_ofile << latest_metrics.SerializeAsString();
|
||||||
metrics_ofile.close();
|
metrics_ofile.close();
|
||||||
}
|
}
|
||||||
|
|
||||||
TFLITE_LOG(INFO) << "Num evaluation runs: " << latest_metrics.num_runs();
|
TFLITE_LOG(INFO) << "Num evaluation runs: " << latest_metrics.num_runs();
|
||||||
const auto& metrics =
|
const auto& metrics =
|
||||||
latest_metrics.process_metrics().image_classification_metrics();
|
latest_metrics.process_metrics().image_classification_metrics();
|
||||||
@ -105,103 +201,11 @@ bool EvaluateModel(const std::string& model_file_path,
|
|||||||
TFLITE_LOG(INFO) << "Top-" << i + 1
|
TFLITE_LOG(INFO) << "Top-" << i + 1
|
||||||
<< " Accuracy: " << accuracy_metrics.topk_accuracies(i);
|
<< " Accuracy: " << accuracy_metrics.topk_accuracies(i);
|
||||||
}
|
}
|
||||||
|
|
||||||
return true;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
int Main(int argc, char* argv[]) {
|
std::unique_ptr<TaskExecutor> CreateTaskExecutor(int* argc, char* argv[]) {
|
||||||
// Command Line Flags.
|
return std::unique_ptr<TaskExecutor>(new ImagenetClassification(argc, argv));
|
||||||
std::string model_file_path;
|
|
||||||
std::string ground_truth_images_path;
|
|
||||||
std::string ground_truth_labels_path;
|
|
||||||
std::string model_output_labels_path;
|
|
||||||
std::string blacklist_file_path;
|
|
||||||
std::string output_file_path;
|
|
||||||
std::string delegate;
|
|
||||||
int num_images = 0;
|
|
||||||
int num_interpreter_threads = 1;
|
|
||||||
std::vector<tflite::Flag> flag_list = {
|
|
||||||
tflite::Flag::CreateFlag(kModelFileFlag, &model_file_path,
|
|
||||||
"Path to test tflite model file."),
|
|
||||||
tflite::Flag::CreateFlag(
|
|
||||||
kModelOutputLabelsFlag, &model_output_labels_path,
|
|
||||||
"Path to labels that correspond to output of model."
|
|
||||||
" E.g. in case of mobilenet, this is the path to label "
|
|
||||||
"file where each label is in the same order as the output"
|
|
||||||
" of the model."),
|
|
||||||
tflite::Flag::CreateFlag(
|
|
||||||
kGroundTruthImagesPathFlag, &ground_truth_images_path,
|
|
||||||
"Path to ground truth images. These will be evaluated in "
|
|
||||||
"alphabetical order of filename"),
|
|
||||||
tflite::Flag::CreateFlag(
|
|
||||||
kGroundTruthLabelsFlag, &ground_truth_labels_path,
|
|
||||||
"Path to ground truth labels, corresponding to alphabetical ordering "
|
|
||||||
"of ground truth images."),
|
|
||||||
tflite::Flag::CreateFlag(
|
|
||||||
kBlacklistFilePathFlag, &blacklist_file_path,
|
|
||||||
"Path to blacklist file (optional) where each line is a single "
|
|
||||||
"integer that is "
|
|
||||||
"equal to index number of blacklisted image."),
|
|
||||||
tflite::Flag::CreateFlag(kOutputFilePathFlag, &output_file_path,
|
|
||||||
"File to output metrics proto to."),
|
|
||||||
tflite::Flag::CreateFlag(kNumImagesFlag, &num_images,
|
|
||||||
"Number of examples to evaluate, pass 0 for all "
|
|
||||||
"examples. Default: 0"),
|
|
||||||
tflite::Flag::CreateFlag(
|
|
||||||
kInterpreterThreadsFlag, &num_interpreter_threads,
|
|
||||||
"Number of interpreter threads to use for inference."),
|
|
||||||
tflite::Flag::CreateFlag(kDelegateFlag, &delegate,
|
|
||||||
"Delegate to use for inference, if available. "
|
|
||||||
"Must be one of {'nnapi', 'gpu'}"),
|
|
||||||
};
|
|
||||||
tflite::Flags::Parse(&argc, const_cast<const char**>(argv), flag_list);
|
|
||||||
DelegateProviders delegate_providers;
|
|
||||||
delegate_providers.InitFromCmdlineArgs(&argc, const_cast<const char**>(argv));
|
|
||||||
|
|
||||||
// Process images in filename-sorted order.
|
|
||||||
std::vector<std::string> image_files, ground_truth_image_labels;
|
|
||||||
TF_LITE_ENSURE_STATUS(GetSortedFileNames(
|
|
||||||
StripTrailingSlashes(ground_truth_images_path), &image_files));
|
|
||||||
if (!ReadFileLines(ground_truth_labels_path, &ground_truth_image_labels)) {
|
|
||||||
TFLITE_LOG(ERROR) << "Could not read ground truth labels file";
|
|
||||||
return EXIT_FAILURE;
|
|
||||||
}
|
|
||||||
if (image_files.size() != ground_truth_image_labels.size()) {
|
|
||||||
TFLITE_LOG(ERROR) << "Number of images and ground truth labels is not same";
|
|
||||||
return EXIT_FAILURE;
|
|
||||||
}
|
|
||||||
std::vector<ImageLabel> image_labels;
|
|
||||||
image_labels.reserve(image_files.size());
|
|
||||||
for (int i = 0; i < image_files.size(); i++) {
|
|
||||||
image_labels.push_back({image_files[i], ground_truth_image_labels[i]});
|
|
||||||
}
|
|
||||||
|
|
||||||
// Filter out blacklisted/unwanted images.
|
|
||||||
TF_LITE_ENSURE_STATUS(
|
|
||||||
FilterBlackListedImages(blacklist_file_path, &image_labels));
|
|
||||||
if (num_images > 0) {
|
|
||||||
image_labels = GetFirstN(image_labels, num_images);
|
|
||||||
}
|
|
||||||
|
|
||||||
std::vector<std::string> model_labels;
|
|
||||||
if (!ReadFileLines(model_output_labels_path, &model_labels)) {
|
|
||||||
TFLITE_LOG(ERROR) << "Could not read model output labels file";
|
|
||||||
return EXIT_FAILURE;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!EvaluateModel(model_file_path, image_labels, model_labels, delegate,
|
|
||||||
output_file_path, num_interpreter_threads,
|
|
||||||
delegate_providers)) {
|
|
||||||
TFLITE_LOG(ERROR) << "Could not evaluate model";
|
|
||||||
return EXIT_FAILURE;
|
|
||||||
}
|
|
||||||
|
|
||||||
return EXIT_SUCCESS;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace evaluation
|
} // namespace evaluation
|
||||||
} // namespace tflite
|
} // namespace tflite
|
||||||
|
|
||||||
int main(int argc, char* argv[]) {
|
|
||||||
return tflite::evaluation::Main(argc, argv);
|
|
||||||
}
|
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
load("//tensorflow/lite:build_def.bzl", "tflite_copts", "tflite_linkopts")
|
load("//tensorflow/lite:build_def.bzl", "tflite_copts")
|
||||||
|
load("//tensorflow/lite/tools/evaluation/tasks:build_def.bzl", "task_linkopts")
|
||||||
|
|
||||||
package(
|
package(
|
||||||
default_visibility = [
|
default_visibility = [
|
||||||
@ -7,18 +8,11 @@ package(
|
|||||||
licenses = ["notice"], # Apache 2.0
|
licenses = ["notice"], # Apache 2.0
|
||||||
)
|
)
|
||||||
|
|
||||||
cc_binary(
|
cc_library(
|
||||||
name = "run_eval",
|
name = "run_eval_lib",
|
||||||
srcs = ["run_eval.cc"],
|
srcs = ["run_eval.cc"],
|
||||||
copts = tflite_copts(),
|
copts = tflite_copts(),
|
||||||
linkopts = tflite_linkopts() + select({
|
linkopts = task_linkopts(),
|
||||||
"//tensorflow:android": [
|
|
||||||
"-pie", # Android 5.0 and later supports only PIE
|
|
||||||
"-lm", # some builtin ops, e.g., tanh, need -lm
|
|
||||||
"-Wl,--rpath=/data/local/tmp/", # Hexagon delegate libraries should be in /data/local/tmp
|
|
||||||
],
|
|
||||||
"//conditions:default": [],
|
|
||||||
}),
|
|
||||||
deps = [
|
deps = [
|
||||||
"//tensorflow/lite/c:common",
|
"//tensorflow/lite/c:common",
|
||||||
"//tensorflow/lite/tools:command_line_flags",
|
"//tensorflow/lite/tools:command_line_flags",
|
||||||
@ -28,5 +22,17 @@ cc_binary(
|
|||||||
"//tensorflow/lite/tools/evaluation/proto:evaluation_config_cc_proto",
|
"//tensorflow/lite/tools/evaluation/proto:evaluation_config_cc_proto",
|
||||||
"//tensorflow/lite/tools/evaluation/proto:evaluation_stages_cc_proto",
|
"//tensorflow/lite/tools/evaluation/proto:evaluation_stages_cc_proto",
|
||||||
"//tensorflow/lite/tools/evaluation/stages:inference_profiler_stage",
|
"//tensorflow/lite/tools/evaluation/stages:inference_profiler_stage",
|
||||||
|
"//tensorflow/lite/tools/evaluation/tasks:task_executor",
|
||||||
|
"@com_google_absl//absl/types:optional",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
cc_binary(
|
||||||
|
name = "run_eval",
|
||||||
|
copts = tflite_copts(),
|
||||||
|
linkopts = task_linkopts(),
|
||||||
|
deps = [
|
||||||
|
":run_eval_lib",
|
||||||
|
"//tensorflow/lite/tools/evaluation/tasks:task_executor_main",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
@ -16,12 +16,14 @@ limitations under the License.
|
|||||||
#include <string>
|
#include <string>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
#include "absl/types/optional.h"
|
||||||
#include "tensorflow/lite/c/common.h"
|
#include "tensorflow/lite/c/common.h"
|
||||||
#include "tensorflow/lite/tools/command_line_flags.h"
|
#include "tensorflow/lite/tools/command_line_flags.h"
|
||||||
#include "tensorflow/lite/tools/evaluation/evaluation_delegate_provider.h"
|
#include "tensorflow/lite/tools/evaluation/evaluation_delegate_provider.h"
|
||||||
#include "tensorflow/lite/tools/evaluation/proto/evaluation_config.pb.h"
|
#include "tensorflow/lite/tools/evaluation/proto/evaluation_config.pb.h"
|
||||||
#include "tensorflow/lite/tools/evaluation/proto/evaluation_stages.pb.h"
|
#include "tensorflow/lite/tools/evaluation/proto/evaluation_stages.pb.h"
|
||||||
#include "tensorflow/lite/tools/evaluation/stages/inference_profiler_stage.h"
|
#include "tensorflow/lite/tools/evaluation/stages/inference_profiler_stage.h"
|
||||||
|
#include "tensorflow/lite/tools/evaluation/tasks/task_executor.h"
|
||||||
#include "tensorflow/lite/tools/logging.h"
|
#include "tensorflow/lite/tools/logging.h"
|
||||||
|
|
||||||
namespace tflite {
|
namespace tflite {
|
||||||
@ -33,43 +35,88 @@ constexpr char kNumRunsFlag[] = "num_runs";
|
|||||||
constexpr char kInterpreterThreadsFlag[] = "num_interpreter_threads";
|
constexpr char kInterpreterThreadsFlag[] = "num_interpreter_threads";
|
||||||
constexpr char kDelegateFlag[] = "delegate";
|
constexpr char kDelegateFlag[] = "delegate";
|
||||||
|
|
||||||
bool EvaluateModel(const std::string& model_file_path,
|
class InferenceDiff : public TaskExecutor {
|
||||||
const std::string& delegate, int num_runs,
|
public:
|
||||||
const std::string& output_file_path,
|
InferenceDiff(int* argc, char* argv[]);
|
||||||
int num_interpreter_threads,
|
~InferenceDiff() override {}
|
||||||
const DelegateProviders& delegate_providers) {
|
|
||||||
|
// If the run is successful, the latest metrics will be returned.
|
||||||
|
absl::optional<EvaluationStageMetrics> Run() final;
|
||||||
|
|
||||||
|
private:
|
||||||
|
void OutputResult(const EvaluationStageMetrics& latest_metrics) const;
|
||||||
|
std::string model_file_path_;
|
||||||
|
std::string output_file_path_;
|
||||||
|
std::string delegate_;
|
||||||
|
int num_runs_;
|
||||||
|
int num_interpreter_threads_;
|
||||||
|
DelegateProviders delegate_providers_;
|
||||||
|
};
|
||||||
|
|
||||||
|
InferenceDiff::InferenceDiff(int* argc, char* argv[])
|
||||||
|
: num_runs_(50), num_interpreter_threads_(1) {
|
||||||
|
// Command Line Flags.
|
||||||
|
std::vector<tflite::Flag> flag_list = {
|
||||||
|
tflite::Flag::CreateFlag(kModelFileFlag, &model_file_path_,
|
||||||
|
"Path to test tflite model file."),
|
||||||
|
tflite::Flag::CreateFlag(kOutputFilePathFlag, &output_file_path_,
|
||||||
|
"File to output metrics proto to."),
|
||||||
|
tflite::Flag::CreateFlag(kNumRunsFlag, &num_runs_,
|
||||||
|
"Number of runs of test & reference inference "
|
||||||
|
"each. Default value: 50"),
|
||||||
|
tflite::Flag::CreateFlag(
|
||||||
|
kInterpreterThreadsFlag, &num_interpreter_threads_,
|
||||||
|
"Number of interpreter threads to use for test inference."),
|
||||||
|
tflite::Flag::CreateFlag(
|
||||||
|
kDelegateFlag, &delegate_,
|
||||||
|
"Delegate to use for test inference, if available. "
|
||||||
|
"Must be one of {'nnapi', 'gpu', 'hexagon', 'xnnpack'}"),
|
||||||
|
};
|
||||||
|
tflite::Flags::Parse(argc, const_cast<const char**>(argv), flag_list);
|
||||||
|
delegate_providers_.InitFromCmdlineArgs(argc, const_cast<const char**>(argv));
|
||||||
|
}
|
||||||
|
|
||||||
|
absl::optional<EvaluationStageMetrics> InferenceDiff::Run() {
|
||||||
// Initialize evaluation stage.
|
// Initialize evaluation stage.
|
||||||
EvaluationStageConfig eval_config;
|
EvaluationStageConfig eval_config;
|
||||||
eval_config.set_name("inference_profiling");
|
eval_config.set_name("inference_profiling");
|
||||||
auto* inference_params =
|
auto* inference_params =
|
||||||
eval_config.mutable_specification()->mutable_tflite_inference_params();
|
eval_config.mutable_specification()->mutable_tflite_inference_params();
|
||||||
inference_params->set_model_file_path(model_file_path);
|
inference_params->set_model_file_path(model_file_path_);
|
||||||
inference_params->set_num_threads(num_interpreter_threads);
|
inference_params->set_num_threads(num_interpreter_threads_);
|
||||||
// This ensures that latency measurement isn't hampered by the time spent in
|
// This ensures that latency measurement isn't hampered by the time spent in
|
||||||
// generating random data.
|
// generating random data.
|
||||||
inference_params->set_invocations_per_run(3);
|
inference_params->set_invocations_per_run(3);
|
||||||
inference_params->set_delegate(ParseStringToDelegateType(delegate));
|
inference_params->set_delegate(ParseStringToDelegateType(delegate_));
|
||||||
if (!delegate.empty() &&
|
if (!delegate_.empty() &&
|
||||||
inference_params->delegate() == TfliteInferenceParams::NONE) {
|
inference_params->delegate() == TfliteInferenceParams::NONE) {
|
||||||
TFLITE_LOG(WARN) << "Unsupported TFLite delegate: " << delegate;
|
TFLITE_LOG(WARN) << "Unsupported TFLite delegate: " << delegate_;
|
||||||
return false;
|
return absl::nullopt;
|
||||||
}
|
}
|
||||||
|
|
||||||
InferenceProfilerStage eval(eval_config);
|
InferenceProfilerStage eval(eval_config);
|
||||||
if (eval.Init(&delegate_providers) != kTfLiteOk) return false;
|
if (eval.Init(&delegate_providers_) != kTfLiteOk) return absl::nullopt;
|
||||||
|
|
||||||
// Run inference & check diff for specified number of runs.
|
// Run inference & check diff for specified number of runs.
|
||||||
for (int i = 0; i < num_runs; ++i) {
|
for (int i = 0; i < num_runs_; ++i) {
|
||||||
if (eval.Run() != kTfLiteOk) return false;
|
if (eval.Run() != kTfLiteOk) return absl::nullopt;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Output latency & diff metrics.
|
|
||||||
const auto latest_metrics = eval.LatestMetrics();
|
const auto latest_metrics = eval.LatestMetrics();
|
||||||
if (!output_file_path.empty()) {
|
OutputResult(latest_metrics);
|
||||||
|
return absl::make_optional(latest_metrics);
|
||||||
|
}
|
||||||
|
|
||||||
|
void InferenceDiff::OutputResult(
|
||||||
|
const EvaluationStageMetrics& latest_metrics) const {
|
||||||
|
// Output latency & diff metrics.
|
||||||
|
if (!output_file_path_.empty()) {
|
||||||
std::ofstream metrics_ofile;
|
std::ofstream metrics_ofile;
|
||||||
metrics_ofile.open(output_file_path, std::ios::out);
|
metrics_ofile.open(output_file_path_, std::ios::out);
|
||||||
metrics_ofile << latest_metrics.SerializeAsString();
|
metrics_ofile << latest_metrics.SerializeAsString();
|
||||||
metrics_ofile.close();
|
metrics_ofile.close();
|
||||||
}
|
}
|
||||||
|
|
||||||
TFLITE_LOG(INFO) << "Num evaluation runs: " << latest_metrics.num_runs();
|
TFLITE_LOG(INFO) << "Num evaluation runs: " << latest_metrics.num_runs();
|
||||||
const auto& metrics =
|
const auto& metrics =
|
||||||
latest_metrics.process_metrics().inference_profiler_metrics();
|
latest_metrics.process_metrics().inference_profiler_metrics();
|
||||||
@ -88,48 +135,11 @@ bool EvaluateModel(const std::string& model_file_path,
|
|||||||
<< "]: avg_error=" << error.avg_value()
|
<< "]: avg_error=" << error.avg_value()
|
||||||
<< ", std_dev=" << error.std_deviation();
|
<< ", std_dev=" << error.std_deviation();
|
||||||
}
|
}
|
||||||
return true;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
int Main(int argc, char* argv[]) {
|
std::unique_ptr<TaskExecutor> CreateTaskExecutor(int* argc, char* argv[]) {
|
||||||
// Command Line Flags.
|
return std::unique_ptr<TaskExecutor>(new InferenceDiff(argc, argv));
|
||||||
std::string model_file_path;
|
|
||||||
std::string output_file_path;
|
|
||||||
std::string delegate;
|
|
||||||
int num_runs = 50;
|
|
||||||
int num_interpreter_threads = 1;
|
|
||||||
std::vector<tflite::Flag> flag_list = {
|
|
||||||
tflite::Flag::CreateFlag(kModelFileFlag, &model_file_path,
|
|
||||||
"Path to test tflite model file."),
|
|
||||||
tflite::Flag::CreateFlag(kOutputFilePathFlag, &output_file_path,
|
|
||||||
"File to output metrics proto to."),
|
|
||||||
tflite::Flag::CreateFlag(kNumRunsFlag, &num_runs,
|
|
||||||
"Number of runs of test & reference inference "
|
|
||||||
"each. Default value: 50"),
|
|
||||||
tflite::Flag::CreateFlag(
|
|
||||||
kInterpreterThreadsFlag, &num_interpreter_threads,
|
|
||||||
"Number of interpreter threads to use for test inference."),
|
|
||||||
tflite::Flag::CreateFlag(
|
|
||||||
kDelegateFlag, &delegate,
|
|
||||||
"Delegate to use for test inference, if available. "
|
|
||||||
"Must be one of {'nnapi', 'gpu', 'hexagon'}"),
|
|
||||||
};
|
|
||||||
tflite::Flags::Parse(&argc, const_cast<const char**>(argv), flag_list);
|
|
||||||
|
|
||||||
DelegateProviders delegate_providers;
|
|
||||||
delegate_providers.InitFromCmdlineArgs(&argc, const_cast<const char**>(argv));
|
|
||||||
if (!EvaluateModel(model_file_path, delegate, num_runs, output_file_path,
|
|
||||||
num_interpreter_threads, delegate_providers)) {
|
|
||||||
TFLITE_LOG(ERROR) << "Could not evaluate model!";
|
|
||||||
return EXIT_FAILURE;
|
|
||||||
}
|
|
||||||
|
|
||||||
return EXIT_SUCCESS;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace evaluation
|
} // namespace evaluation
|
||||||
} // namespace tflite
|
} // namespace tflite
|
||||||
|
|
||||||
int main(int argc, char* argv[]) {
|
|
||||||
return tflite::evaluation::Main(argc, argv);
|
|
||||||
}
|
|
||||||
|
38
tensorflow/lite/tools/evaluation/tasks/task_executor.h
Normal file
38
tensorflow/lite/tools/evaluation/tasks/task_executor.h
Normal file
@ -0,0 +1,38 @@
|
|||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
#ifndef TENSORFLOW_LITE_TOOLS_EVALUATION_TASKS_TASK_EXECUTOR_H_
|
||||||
|
#define TENSORFLOW_LITE_TOOLS_EVALUATION_TASKS_TASK_EXECUTOR_H_
|
||||||
|
|
||||||
|
#include "absl/types/optional.h"
|
||||||
|
#include "tensorflow/lite/tools/evaluation/proto/evaluation_config.pb.h"
|
||||||
|
|
||||||
|
namespace tflite {
|
||||||
|
namespace evaluation {
|
||||||
|
// A common task execution API to avoid boilerpolate code in defining the main
|
||||||
|
// function.
|
||||||
|
class TaskExecutor {
|
||||||
|
public:
|
||||||
|
virtual ~TaskExecutor() {}
|
||||||
|
// If the run is successful, the latest metrics will be returned.
|
||||||
|
virtual absl::optional<EvaluationStageMetrics> Run() = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Just a declaration. In order to avoid the boilerpolate main-function code,
|
||||||
|
// every evaluation task should define this function.
|
||||||
|
std::unique_ptr<TaskExecutor> CreateTaskExecutor(int* argc, char* argv[]);
|
||||||
|
} // namespace evaluation
|
||||||
|
} // namespace tflite
|
||||||
|
|
||||||
|
#endif // TENSORFLOW_LITE_TOOLS_EVALUATION_TASKS_TASK_EXECUTOR_H_
|
32
tensorflow/lite/tools/evaluation/tasks/task_executor_main.cc
Normal file
32
tensorflow/lite/tools/evaluation/tasks/task_executor_main.cc
Normal file
@ -0,0 +1,32 @@
|
|||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
#include "absl/types/optional.h"
|
||||||
|
#include "tensorflow/lite/tools/evaluation/tasks/task_executor.h"
|
||||||
|
#include "tensorflow/lite/tools/logging.h"
|
||||||
|
|
||||||
|
// This could serve as the main function for all eval tools.
|
||||||
|
int main(int argc, char* argv[]) {
|
||||||
|
auto task_executor = tflite::evaluation::CreateTaskExecutor(&argc, argv);
|
||||||
|
if (task_executor == nullptr) {
|
||||||
|
TFLITE_LOG(ERROR) << "Could not create the task evaluation!";
|
||||||
|
return EXIT_FAILURE;
|
||||||
|
}
|
||||||
|
const auto metrics = task_executor->Run();
|
||||||
|
if (!metrics.has_value()) {
|
||||||
|
TFLITE_LOG(ERROR) << "Could not run the task evaluation!";
|
||||||
|
return EXIT_FAILURE;
|
||||||
|
}
|
||||||
|
return EXIT_SUCCESS;
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user