Use the same task main stub for the coco object detection eval task.
PiperOrigin-RevId: 307503491 Change-Id: I03b9a42c09fc155e3883cdd4b852bd61911dc4a8
This commit is contained in:
parent
e7046ce5d6
commit
b422faa5e3
tensorflow/lite/tools/evaluation
@ -157,7 +157,7 @@ EvaluationStageMetrics ObjectDetectionStage::LatestMetrics() {
|
||||
}
|
||||
|
||||
TfLiteStatus PopulateGroundTruth(
|
||||
const std::string& grouth_truth_pbtxt_file,
|
||||
const std::string& grouth_truth_proto_file,
|
||||
absl::flat_hash_map<std::string, ObjectDetectionResult>*
|
||||
ground_truth_mapping) {
|
||||
if (ground_truth_mapping == nullptr) {
|
||||
@ -166,7 +166,7 @@ TfLiteStatus PopulateGroundTruth(
|
||||
ground_truth_mapping->clear();
|
||||
|
||||
// Read the ground truth dump.
|
||||
std::ifstream t(grouth_truth_pbtxt_file);
|
||||
std::ifstream t(grouth_truth_proto_file);
|
||||
std::string proto_str((std::istreambuf_iterator<char>(t)),
|
||||
std::istreambuf_iterator<char>());
|
||||
ObjectDetectionGroundTruth ground_truth_proto;
|
||||
|
@ -97,7 +97,7 @@ class ObjectDetectionStage : public EvaluationStage {
|
||||
// preprocess_coco_minival.py script in evaluation/tasks/coco_object_detection.
|
||||
// Useful for wrappers/scripts that use ObjectDetectionStage.
|
||||
TfLiteStatus PopulateGroundTruth(
|
||||
const std::string& grouth_truth_pbtxt_file,
|
||||
const std::string& grouth_truth_proto_file,
|
||||
absl::flat_hash_map<std::string, ObjectDetectionResult>*
|
||||
ground_truth_mapping);
|
||||
|
||||
|
@ -1,4 +1,5 @@
|
||||
load("//tensorflow/lite:build_def.bzl", "tflite_copts", "tflite_linkopts")
|
||||
load("//tensorflow/lite:build_def.bzl", "tflite_copts")
|
||||
load("//tensorflow/lite/tools/evaluation/tasks:build_def.bzl", "task_linkopts")
|
||||
|
||||
package(
|
||||
default_visibility = [
|
||||
@ -16,18 +17,11 @@ py_binary(
|
||||
deps = ["//tensorflow/lite/tools/evaluation/proto:evaluation_stages_py"],
|
||||
)
|
||||
|
||||
cc_binary(
|
||||
name = "run_eval",
|
||||
cc_library(
|
||||
name = "run_eval_lib",
|
||||
srcs = ["run_eval.cc"],
|
||||
copts = tflite_copts(),
|
||||
linkopts = tflite_linkopts() + select({
|
||||
"//tensorflow:android": [
|
||||
"-pie", # Android 5.0 and later supports only PIE
|
||||
"-lm", # some builtin ops, e.g., tanh, need -lm
|
||||
"-Wl,--rpath=/data/local/tmp/", # Hexagon delegate libraries should be in /data/local/tmp
|
||||
],
|
||||
"//conditions:default": [],
|
||||
}),
|
||||
linkopts = task_linkopts(),
|
||||
deps = [
|
||||
"//tensorflow/lite/c:common",
|
||||
"//tensorflow/lite/tools:command_line_flags",
|
||||
@ -38,6 +32,18 @@ cc_binary(
|
||||
"//tensorflow/lite/tools/evaluation/proto:evaluation_config_cc_proto",
|
||||
"//tensorflow/lite/tools/evaluation/proto:evaluation_stages_cc_proto",
|
||||
"//tensorflow/lite/tools/evaluation/stages:object_detection_stage",
|
||||
"//tensorflow/lite/tools/evaluation/tasks:task_executor",
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
"@com_google_absl//absl/types:optional",
|
||||
],
|
||||
)
|
||||
|
||||
cc_binary(
|
||||
name = "run_eval",
|
||||
copts = tflite_copts(),
|
||||
linkopts = task_linkopts(),
|
||||
deps = [
|
||||
":run_eval_lib",
|
||||
"//tensorflow/lite/tools/evaluation/tasks:task_executor_main",
|
||||
],
|
||||
)
|
||||
|
@ -18,12 +18,14 @@ limitations under the License.
|
||||
#include <vector>
|
||||
|
||||
#include "absl/container/flat_hash_map.h"
|
||||
#include "absl/types/optional.h"
|
||||
#include "tensorflow/lite/c/common.h"
|
||||
#include "tensorflow/lite/tools/command_line_flags.h"
|
||||
#include "tensorflow/lite/tools/evaluation/evaluation_delegate_provider.h"
|
||||
#include "tensorflow/lite/tools/evaluation/proto/evaluation_config.pb.h"
|
||||
#include "tensorflow/lite/tools/evaluation/proto/evaluation_stages.pb.h"
|
||||
#include "tensorflow/lite/tools/evaluation/stages/object_detection_stage.h"
|
||||
#include "tensorflow/lite/tools/evaluation/tasks/task_executor.h"
|
||||
#include "tensorflow/lite/tools/evaluation/utils.h"
|
||||
#include "tensorflow/lite/tools/logging.h"
|
||||
|
||||
@ -45,37 +47,102 @@ std::string GetNameFromPath(const std::string& str) {
|
||||
return str.substr(pos + 1);
|
||||
}
|
||||
|
||||
bool EvaluateModel(const std::string& model_file_path,
|
||||
const std::vector<std::string>& model_labels,
|
||||
const std::vector<std::string>& image_paths,
|
||||
const std::string& ground_truth_proto_file,
|
||||
std::string delegate, std::string output_file_path,
|
||||
int num_interpreter_threads, bool debug_mode,
|
||||
const DelegateProviders& delegate_providers) {
|
||||
class CocoObjectDetection : public TaskExecutor {
|
||||
public:
|
||||
CocoObjectDetection(int* argc, char* argv[]);
|
||||
~CocoObjectDetection() override {}
|
||||
|
||||
// If the run is successful, the latest metrics will be returned.
|
||||
absl::optional<EvaluationStageMetrics> Run() final;
|
||||
|
||||
private:
|
||||
void OutputResult(const EvaluationStageMetrics& latest_metrics) const;
|
||||
std::string model_file_path_;
|
||||
std::string model_output_labels_path_;
|
||||
std::string ground_truth_images_path_;
|
||||
std::string ground_truth_proto_file_;
|
||||
std::string output_file_path_;
|
||||
bool debug_mode_;
|
||||
std::string delegate_;
|
||||
int num_interpreter_threads_;
|
||||
DelegateProviders delegate_providers_;
|
||||
};
|
||||
|
||||
CocoObjectDetection::CocoObjectDetection(int* argc, char* argv[])
|
||||
: debug_mode_(false), num_interpreter_threads_(1) {
|
||||
std::vector<tflite::Flag> flag_list = {
|
||||
tflite::Flag::CreateFlag(kModelFileFlag, &model_file_path_,
|
||||
"Path to test tflite model file."),
|
||||
tflite::Flag::CreateFlag(
|
||||
kModelOutputLabelsFlag, &model_output_labels_path_,
|
||||
"Path to labels that correspond to output of model."
|
||||
" E.g. in case of COCO-trained SSD model, this is the path to file "
|
||||
"where each line contains a class detected by the model in correct "
|
||||
"order, starting from background."),
|
||||
tflite::Flag::CreateFlag(
|
||||
kGroundTruthImagesPathFlag, &ground_truth_images_path_,
|
||||
"Path to ground truth images. These will be evaluated in "
|
||||
"alphabetical order of filenames"),
|
||||
tflite::Flag::CreateFlag(kGroundTruthProtoFileFlag,
|
||||
&ground_truth_proto_file_,
|
||||
"Path to file containing "
|
||||
"tflite::evaluation::ObjectDetectionGroundTruth "
|
||||
"proto in binary serialized format. If left "
|
||||
"empty, mAP numbers are not output."),
|
||||
tflite::Flag::CreateFlag(
|
||||
kOutputFilePathFlag, &output_file_path_,
|
||||
"File to output to. Contains only metrics proto if debug_mode is "
|
||||
"off, and per-image predictions also otherwise."),
|
||||
tflite::Flag::CreateFlag(kDebugModeFlag, &debug_mode_,
|
||||
"Whether to enable debug mode. Per-image "
|
||||
"predictions are written to the output file "
|
||||
"along with metrics."),
|
||||
tflite::Flag::CreateFlag(
|
||||
kInterpreterThreadsFlag, &num_interpreter_threads_,
|
||||
"Number of interpreter threads to use for inference."),
|
||||
tflite::Flag::CreateFlag(
|
||||
kDelegateFlag, &delegate_,
|
||||
"Delegate to use for inference, if available. "
|
||||
"Must be one of {'nnapi', 'gpu', 'xnnpack', 'hexagon'}"),
|
||||
};
|
||||
tflite::Flags::Parse(argc, const_cast<const char**>(argv), flag_list);
|
||||
DelegateProviders delegate_providers;
|
||||
delegate_providers.InitFromCmdlineArgs(argc, const_cast<const char**>(argv));
|
||||
}
|
||||
|
||||
absl::optional<EvaluationStageMetrics> CocoObjectDetection::Run() {
|
||||
// Process images in filename-sorted order.
|
||||
std::vector<std::string> image_paths;
|
||||
if (GetSortedFileNames(StripTrailingSlashes(ground_truth_images_path_),
|
||||
&image_paths) != kTfLiteOk) {
|
||||
return absl::nullopt;
|
||||
}
|
||||
|
||||
std::vector<std::string> model_labels;
|
||||
if (!ReadFileLines(model_output_labels_path_, &model_labels)) {
|
||||
TFLITE_LOG(ERROR) << "Could not read model output labels file";
|
||||
return absl::nullopt;
|
||||
}
|
||||
|
||||
EvaluationStageConfig eval_config;
|
||||
eval_config.set_name("object_detection");
|
||||
auto* detection_params =
|
||||
eval_config.mutable_specification()->mutable_object_detection_params();
|
||||
auto* inference_params = detection_params->mutable_inference_params();
|
||||
inference_params->set_model_file_path(model_file_path);
|
||||
inference_params->set_num_threads(num_interpreter_threads);
|
||||
inference_params->set_delegate(ParseStringToDelegateType(delegate));
|
||||
if (!delegate.empty() &&
|
||||
inference_params->delegate() == TfliteInferenceParams::NONE) {
|
||||
TFLITE_LOG(WARN) << "Unsupported TFLite delegate: " << delegate;
|
||||
return false;
|
||||
}
|
||||
inference_params->set_model_file_path(model_file_path_);
|
||||
inference_params->set_num_threads(num_interpreter_threads_);
|
||||
inference_params->set_delegate(ParseStringToDelegateType(delegate_));
|
||||
|
||||
// Get ground truth data.
|
||||
absl::flat_hash_map<std::string, ObjectDetectionResult> ground_truth_map;
|
||||
if (!ground_truth_proto_file.empty()) {
|
||||
PopulateGroundTruth(ground_truth_proto_file, &ground_truth_map);
|
||||
if (!ground_truth_proto_file_.empty()) {
|
||||
PopulateGroundTruth(ground_truth_proto_file_, &ground_truth_map);
|
||||
}
|
||||
|
||||
ObjectDetectionStage eval(eval_config);
|
||||
|
||||
eval.SetAllLabels(model_labels);
|
||||
if (eval.Init(&delegate_providers) != kTfLiteOk) return false;
|
||||
if (eval.Init(&delegate_providers_) != kTfLiteOk) return absl::nullopt;
|
||||
|
||||
const int step = image_paths.size() / 100;
|
||||
for (int i = 0; i < image_paths.size(); ++i) {
|
||||
@ -85,9 +152,9 @@ bool EvaluateModel(const std::string& model_file_path,
|
||||
|
||||
const std::string image_name = GetNameFromPath(image_paths[i]);
|
||||
eval.SetInputs(image_paths[i], ground_truth_map[image_name]);
|
||||
if (eval.Run() != kTfLiteOk) return false;
|
||||
if (eval.Run() != kTfLiteOk) return absl::nullopt;
|
||||
|
||||
if (debug_mode) {
|
||||
if (debug_mode_) {
|
||||
ObjectDetectionResult prediction = *eval.GetLatestPrediction();
|
||||
TFLITE_LOG(INFO) << "Image: " << image_name << "\n";
|
||||
for (int i = 0; i < prediction.objects_size(); ++i) {
|
||||
@ -113,15 +180,22 @@ bool EvaluateModel(const std::string& model_file_path,
|
||||
|
||||
// Write metrics to file.
|
||||
EvaluationStageMetrics latest_metrics = eval.LatestMetrics();
|
||||
if (ground_truth_proto_file.empty()) {
|
||||
// mAP metrics are meaningless for no ground truth.
|
||||
if (ground_truth_proto_file_.empty()) {
|
||||
TFLITE_LOG(WARN) << "mAP metrics are meaningless w/o ground truth.";
|
||||
latest_metrics.mutable_process_metrics()
|
||||
->mutable_object_detection_metrics()
|
||||
->clear_average_precision_metrics();
|
||||
}
|
||||
if (!output_file_path.empty()) {
|
||||
|
||||
OutputResult(latest_metrics);
|
||||
return absl::make_optional(latest_metrics);
|
||||
}
|
||||
|
||||
void CocoObjectDetection::OutputResult(
|
||||
const EvaluationStageMetrics& latest_metrics) const {
|
||||
if (!output_file_path_.empty()) {
|
||||
std::ofstream metrics_ofile;
|
||||
metrics_ofile.open(output_file_path, std::ios::out);
|
||||
metrics_ofile.open(output_file_path_, std::ios::out);
|
||||
metrics_ofile << latest_metrics.SerializeAsString();
|
||||
metrics_ofile.close();
|
||||
}
|
||||
@ -148,81 +222,11 @@ bool EvaluateModel(const std::string& model_file_path,
|
||||
}
|
||||
TFLITE_LOG(INFO) << "Overall mAP: "
|
||||
<< precision_metrics.overall_mean_average_precision();
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
int Main(int argc, char* argv[]) {
|
||||
// Command Line Flags.
|
||||
std::string model_file_path;
|
||||
std::string ground_truth_images_path;
|
||||
std::string ground_truth_proto_file;
|
||||
std::string model_output_labels_path;
|
||||
std::string output_file_path;
|
||||
std::string delegate;
|
||||
int num_interpreter_threads = 1;
|
||||
bool debug_mode;
|
||||
std::vector<tflite::Flag> flag_list = {
|
||||
tflite::Flag::CreateFlag(kModelFileFlag, &model_file_path,
|
||||
"Path to test tflite model file."),
|
||||
tflite::Flag::CreateFlag(
|
||||
kModelOutputLabelsFlag, &model_output_labels_path,
|
||||
"Path to labels that correspond to output of model."
|
||||
" E.g. in case of COCO-trained SSD model, this is the path to file "
|
||||
"where each line contains a class detected by the model in correct "
|
||||
"order, starting from background."),
|
||||
tflite::Flag::CreateFlag(
|
||||
kGroundTruthImagesPathFlag, &ground_truth_images_path,
|
||||
"Path to ground truth images. These will be evaluated in "
|
||||
"alphabetical order of filenames"),
|
||||
tflite::Flag::CreateFlag(
|
||||
kGroundTruthProtoFileFlag, &ground_truth_proto_file,
|
||||
"Path to file containing "
|
||||
"tflite::evaluation::ObjectDetectionGroundTruth "
|
||||
"proto in text format. If left empty, mAP numbers are not output."),
|
||||
tflite::Flag::CreateFlag(
|
||||
kOutputFilePathFlag, &output_file_path,
|
||||
"File to output to. Contains only metrics proto if debug_mode is "
|
||||
"off, and per-image predictions also otherwise."),
|
||||
tflite::Flag::CreateFlag(kDebugModeFlag, &debug_mode,
|
||||
"Whether to enable debug mode. Per-image "
|
||||
"predictions are written to the output file "
|
||||
"along with metrics."),
|
||||
tflite::Flag::CreateFlag(
|
||||
kInterpreterThreadsFlag, &num_interpreter_threads,
|
||||
"Number of interpreter threads to use for inference."),
|
||||
tflite::Flag::CreateFlag(kDelegateFlag, &delegate,
|
||||
"Delegate to use for inference, if available. "
|
||||
"Must be one of {'nnapi', 'gpu'}"),
|
||||
};
|
||||
tflite::Flags::Parse(&argc, const_cast<const char**>(argv), flag_list);
|
||||
DelegateProviders delegate_providers;
|
||||
delegate_providers.InitFromCmdlineArgs(&argc, const_cast<const char**>(argv));
|
||||
|
||||
// Process images in filename-sorted order.
|
||||
std::vector<std::string> image_paths;
|
||||
TF_LITE_ENSURE_STATUS(GetSortedFileNames(
|
||||
StripTrailingSlashes(ground_truth_images_path), &image_paths));
|
||||
|
||||
std::vector<std::string> model_labels;
|
||||
if (!ReadFileLines(model_output_labels_path, &model_labels)) {
|
||||
TFLITE_LOG(ERROR) << "Could not read model output labels file";
|
||||
return EXIT_FAILURE;
|
||||
}
|
||||
|
||||
if (!EvaluateModel(model_file_path, model_labels, image_paths,
|
||||
ground_truth_proto_file, delegate, output_file_path,
|
||||
num_interpreter_threads, debug_mode, delegate_providers)) {
|
||||
TFLITE_LOG(ERROR) << "Could not evaluate model";
|
||||
return EXIT_FAILURE;
|
||||
}
|
||||
|
||||
return EXIT_SUCCESS;
|
||||
std::unique_ptr<TaskExecutor> CreateTaskExecutor(int* argc, char* argv[]) {
|
||||
return std::unique_ptr<TaskExecutor>(new CocoObjectDetection(argc, argv));
|
||||
}
|
||||
|
||||
} // namespace evaluation
|
||||
} // namespace tflite
|
||||
|
||||
int main(int argc, char* argv[]) {
|
||||
return tflite::evaluation::Main(argc, argv);
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user