Use the same task main stub for the coco object detection eval task.

PiperOrigin-RevId: 307503491
Change-Id: I03b9a42c09fc155e3883cdd4b852bd61911dc4a8
This commit is contained in:
Chao Mei 2020-04-20 16:51:26 -07:00 committed by TensorFlower Gardener
parent e7046ce5d6
commit b422faa5e3
4 changed files with 120 additions and 110 deletions
tensorflow/lite/tools/evaluation

View File

@ -157,7 +157,7 @@ EvaluationStageMetrics ObjectDetectionStage::LatestMetrics() {
}
TfLiteStatus PopulateGroundTruth(
const std::string& grouth_truth_pbtxt_file,
const std::string& grouth_truth_proto_file,
absl::flat_hash_map<std::string, ObjectDetectionResult>*
ground_truth_mapping) {
if (ground_truth_mapping == nullptr) {
@ -166,7 +166,7 @@ TfLiteStatus PopulateGroundTruth(
ground_truth_mapping->clear();
// Read the ground truth dump.
std::ifstream t(grouth_truth_pbtxt_file);
std::ifstream t(grouth_truth_proto_file);
std::string proto_str((std::istreambuf_iterator<char>(t)),
std::istreambuf_iterator<char>());
ObjectDetectionGroundTruth ground_truth_proto;

View File

@ -97,7 +97,7 @@ class ObjectDetectionStage : public EvaluationStage {
// preprocess_coco_minival.py script in evaluation/tasks/coco_object_detection.
// Useful for wrappers/scripts that use ObjectDetectionStage.
TfLiteStatus PopulateGroundTruth(
const std::string& grouth_truth_pbtxt_file,
const std::string& grouth_truth_proto_file,
absl::flat_hash_map<std::string, ObjectDetectionResult>*
ground_truth_mapping);

View File

@ -1,4 +1,5 @@
load("//tensorflow/lite:build_def.bzl", "tflite_copts", "tflite_linkopts")
load("//tensorflow/lite:build_def.bzl", "tflite_copts")
load("//tensorflow/lite/tools/evaluation/tasks:build_def.bzl", "task_linkopts")
package(
default_visibility = [
@ -16,18 +17,11 @@ py_binary(
deps = ["//tensorflow/lite/tools/evaluation/proto:evaluation_stages_py"],
)
cc_binary(
name = "run_eval",
cc_library(
name = "run_eval_lib",
srcs = ["run_eval.cc"],
copts = tflite_copts(),
linkopts = tflite_linkopts() + select({
"//tensorflow:android": [
"-pie", # Android 5.0 and later supports only PIE
"-lm", # some builtin ops, e.g., tanh, need -lm
"-Wl,--rpath=/data/local/tmp/", # Hexagon delegate libraries should be in /data/local/tmp
],
"//conditions:default": [],
}),
linkopts = task_linkopts(),
deps = [
"//tensorflow/lite/c:common",
"//tensorflow/lite/tools:command_line_flags",
@ -38,6 +32,18 @@ cc_binary(
"//tensorflow/lite/tools/evaluation/proto:evaluation_config_cc_proto",
"//tensorflow/lite/tools/evaluation/proto:evaluation_stages_cc_proto",
"//tensorflow/lite/tools/evaluation/stages:object_detection_stage",
"//tensorflow/lite/tools/evaluation/tasks:task_executor",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/types:optional",
],
)
cc_binary(
name = "run_eval",
copts = tflite_copts(),
linkopts = task_linkopts(),
deps = [
":run_eval_lib",
"//tensorflow/lite/tools/evaluation/tasks:task_executor_main",
],
)

View File

@ -18,12 +18,14 @@ limitations under the License.
#include <vector>
#include "absl/container/flat_hash_map.h"
#include "absl/types/optional.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/tools/command_line_flags.h"
#include "tensorflow/lite/tools/evaluation/evaluation_delegate_provider.h"
#include "tensorflow/lite/tools/evaluation/proto/evaluation_config.pb.h"
#include "tensorflow/lite/tools/evaluation/proto/evaluation_stages.pb.h"
#include "tensorflow/lite/tools/evaluation/stages/object_detection_stage.h"
#include "tensorflow/lite/tools/evaluation/tasks/task_executor.h"
#include "tensorflow/lite/tools/evaluation/utils.h"
#include "tensorflow/lite/tools/logging.h"
@ -45,37 +47,102 @@ std::string GetNameFromPath(const std::string& str) {
return str.substr(pos + 1);
}
bool EvaluateModel(const std::string& model_file_path,
const std::vector<std::string>& model_labels,
const std::vector<std::string>& image_paths,
const std::string& ground_truth_proto_file,
std::string delegate, std::string output_file_path,
int num_interpreter_threads, bool debug_mode,
const DelegateProviders& delegate_providers) {
class CocoObjectDetection : public TaskExecutor {
public:
CocoObjectDetection(int* argc, char* argv[]);
~CocoObjectDetection() override {}
// If the run is successful, the latest metrics will be returned.
absl::optional<EvaluationStageMetrics> Run() final;
private:
void OutputResult(const EvaluationStageMetrics& latest_metrics) const;
std::string model_file_path_;
std::string model_output_labels_path_;
std::string ground_truth_images_path_;
std::string ground_truth_proto_file_;
std::string output_file_path_;
bool debug_mode_;
std::string delegate_;
int num_interpreter_threads_;
DelegateProviders delegate_providers_;
};
CocoObjectDetection::CocoObjectDetection(int* argc, char* argv[])
: debug_mode_(false), num_interpreter_threads_(1) {
std::vector<tflite::Flag> flag_list = {
tflite::Flag::CreateFlag(kModelFileFlag, &model_file_path_,
"Path to test tflite model file."),
tflite::Flag::CreateFlag(
kModelOutputLabelsFlag, &model_output_labels_path_,
"Path to labels that correspond to output of model."
" E.g. in case of COCO-trained SSD model, this is the path to file "
"where each line contains a class detected by the model in correct "
"order, starting from background."),
tflite::Flag::CreateFlag(
kGroundTruthImagesPathFlag, &ground_truth_images_path_,
"Path to ground truth images. These will be evaluated in "
"alphabetical order of filenames"),
tflite::Flag::CreateFlag(kGroundTruthProtoFileFlag,
&ground_truth_proto_file_,
"Path to file containing "
"tflite::evaluation::ObjectDetectionGroundTruth "
"proto in binary serialized format. If left "
"empty, mAP numbers are not output."),
tflite::Flag::CreateFlag(
kOutputFilePathFlag, &output_file_path_,
"File to output to. Contains only metrics proto if debug_mode is "
"off, and per-image predictions also otherwise."),
tflite::Flag::CreateFlag(kDebugModeFlag, &debug_mode_,
"Whether to enable debug mode. Per-image "
"predictions are written to the output file "
"along with metrics."),
tflite::Flag::CreateFlag(
kInterpreterThreadsFlag, &num_interpreter_threads_,
"Number of interpreter threads to use for inference."),
tflite::Flag::CreateFlag(
kDelegateFlag, &delegate_,
"Delegate to use for inference, if available. "
"Must be one of {'nnapi', 'gpu', 'xnnpack', 'hexagon'}"),
};
tflite::Flags::Parse(argc, const_cast<const char**>(argv), flag_list);
DelegateProviders delegate_providers;
delegate_providers.InitFromCmdlineArgs(argc, const_cast<const char**>(argv));
}
absl::optional<EvaluationStageMetrics> CocoObjectDetection::Run() {
// Process images in filename-sorted order.
std::vector<std::string> image_paths;
if (GetSortedFileNames(StripTrailingSlashes(ground_truth_images_path_),
&image_paths) != kTfLiteOk) {
return absl::nullopt;
}
std::vector<std::string> model_labels;
if (!ReadFileLines(model_output_labels_path_, &model_labels)) {
TFLITE_LOG(ERROR) << "Could not read model output labels file";
return absl::nullopt;
}
EvaluationStageConfig eval_config;
eval_config.set_name("object_detection");
auto* detection_params =
eval_config.mutable_specification()->mutable_object_detection_params();
auto* inference_params = detection_params->mutable_inference_params();
inference_params->set_model_file_path(model_file_path);
inference_params->set_num_threads(num_interpreter_threads);
inference_params->set_delegate(ParseStringToDelegateType(delegate));
if (!delegate.empty() &&
inference_params->delegate() == TfliteInferenceParams::NONE) {
TFLITE_LOG(WARN) << "Unsupported TFLite delegate: " << delegate;
return false;
}
inference_params->set_model_file_path(model_file_path_);
inference_params->set_num_threads(num_interpreter_threads_);
inference_params->set_delegate(ParseStringToDelegateType(delegate_));
// Get ground truth data.
absl::flat_hash_map<std::string, ObjectDetectionResult> ground_truth_map;
if (!ground_truth_proto_file.empty()) {
PopulateGroundTruth(ground_truth_proto_file, &ground_truth_map);
if (!ground_truth_proto_file_.empty()) {
PopulateGroundTruth(ground_truth_proto_file_, &ground_truth_map);
}
ObjectDetectionStage eval(eval_config);
eval.SetAllLabels(model_labels);
if (eval.Init(&delegate_providers) != kTfLiteOk) return false;
if (eval.Init(&delegate_providers_) != kTfLiteOk) return absl::nullopt;
const int step = image_paths.size() / 100;
for (int i = 0; i < image_paths.size(); ++i) {
@ -85,9 +152,9 @@ bool EvaluateModel(const std::string& model_file_path,
const std::string image_name = GetNameFromPath(image_paths[i]);
eval.SetInputs(image_paths[i], ground_truth_map[image_name]);
if (eval.Run() != kTfLiteOk) return false;
if (eval.Run() != kTfLiteOk) return absl::nullopt;
if (debug_mode) {
if (debug_mode_) {
ObjectDetectionResult prediction = *eval.GetLatestPrediction();
TFLITE_LOG(INFO) << "Image: " << image_name << "\n";
for (int i = 0; i < prediction.objects_size(); ++i) {
@ -113,15 +180,22 @@ bool EvaluateModel(const std::string& model_file_path,
// Write metrics to file.
EvaluationStageMetrics latest_metrics = eval.LatestMetrics();
if (ground_truth_proto_file.empty()) {
// mAP metrics are meaningless for no ground truth.
if (ground_truth_proto_file_.empty()) {
TFLITE_LOG(WARN) << "mAP metrics are meaningless w/o ground truth.";
latest_metrics.mutable_process_metrics()
->mutable_object_detection_metrics()
->clear_average_precision_metrics();
}
if (!output_file_path.empty()) {
OutputResult(latest_metrics);
return absl::make_optional(latest_metrics);
}
void CocoObjectDetection::OutputResult(
const EvaluationStageMetrics& latest_metrics) const {
if (!output_file_path_.empty()) {
std::ofstream metrics_ofile;
metrics_ofile.open(output_file_path, std::ios::out);
metrics_ofile.open(output_file_path_, std::ios::out);
metrics_ofile << latest_metrics.SerializeAsString();
metrics_ofile.close();
}
@ -148,81 +222,11 @@ bool EvaluateModel(const std::string& model_file_path,
}
TFLITE_LOG(INFO) << "Overall mAP: "
<< precision_metrics.overall_mean_average_precision();
return true;
}
int Main(int argc, char* argv[]) {
// Command Line Flags.
std::string model_file_path;
std::string ground_truth_images_path;
std::string ground_truth_proto_file;
std::string model_output_labels_path;
std::string output_file_path;
std::string delegate;
int num_interpreter_threads = 1;
bool debug_mode;
std::vector<tflite::Flag> flag_list = {
tflite::Flag::CreateFlag(kModelFileFlag, &model_file_path,
"Path to test tflite model file."),
tflite::Flag::CreateFlag(
kModelOutputLabelsFlag, &model_output_labels_path,
"Path to labels that correspond to output of model."
" E.g. in case of COCO-trained SSD model, this is the path to file "
"where each line contains a class detected by the model in correct "
"order, starting from background."),
tflite::Flag::CreateFlag(
kGroundTruthImagesPathFlag, &ground_truth_images_path,
"Path to ground truth images. These will be evaluated in "
"alphabetical order of filenames"),
tflite::Flag::CreateFlag(
kGroundTruthProtoFileFlag, &ground_truth_proto_file,
"Path to file containing "
"tflite::evaluation::ObjectDetectionGroundTruth "
"proto in text format. If left empty, mAP numbers are not output."),
tflite::Flag::CreateFlag(
kOutputFilePathFlag, &output_file_path,
"File to output to. Contains only metrics proto if debug_mode is "
"off, and per-image predictions also otherwise."),
tflite::Flag::CreateFlag(kDebugModeFlag, &debug_mode,
"Whether to enable debug mode. Per-image "
"predictions are written to the output file "
"along with metrics."),
tflite::Flag::CreateFlag(
kInterpreterThreadsFlag, &num_interpreter_threads,
"Number of interpreter threads to use for inference."),
tflite::Flag::CreateFlag(kDelegateFlag, &delegate,
"Delegate to use for inference, if available. "
"Must be one of {'nnapi', 'gpu'}"),
};
tflite::Flags::Parse(&argc, const_cast<const char**>(argv), flag_list);
DelegateProviders delegate_providers;
delegate_providers.InitFromCmdlineArgs(&argc, const_cast<const char**>(argv));
// Process images in filename-sorted order.
std::vector<std::string> image_paths;
TF_LITE_ENSURE_STATUS(GetSortedFileNames(
StripTrailingSlashes(ground_truth_images_path), &image_paths));
std::vector<std::string> model_labels;
if (!ReadFileLines(model_output_labels_path, &model_labels)) {
TFLITE_LOG(ERROR) << "Could not read model output labels file";
return EXIT_FAILURE;
}
if (!EvaluateModel(model_file_path, model_labels, image_paths,
ground_truth_proto_file, delegate, output_file_path,
num_interpreter_threads, debug_mode, delegate_providers)) {
TFLITE_LOG(ERROR) << "Could not evaluate model";
return EXIT_FAILURE;
}
return EXIT_SUCCESS;
std::unique_ptr<TaskExecutor> CreateTaskExecutor(int* argc, char* argv[]) {
return std::unique_ptr<TaskExecutor>(new CocoObjectDetection(argc, argv));
}
} // namespace evaluation
} // namespace tflite
int main(int argc, char* argv[]) {
return tflite::evaluation::Main(argc, argv);
}