Refactors ImagenetModelEvaluator to use ImageClassificationStage.

PiperOrigin-RevId: 244374565
Sachin Joglekar 2019-04-19 09:58:32 -07:00 committed by TensorFlower Gardener
parent c767cbf14c
commit 37b20d191a
14 changed files with 351 additions and 359 deletions
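Since the diff below is easier to follow with the new flow in mind, here is a minimal, hypothetical sketch of the ImageClassificationStage usage this change adopts. It mirrors the new code in imagenet_model_evaluator.cc further down; the wrapper function, its arguments, and the choice of k = 10 are illustrative only.

#include <string>
#include <vector>

#include "tensorflow/core/platform/logging.h"
#include "tensorflow/lite/c/c_api_internal.h"
#include "tensorflow/lite/tools/evaluation/proto/evaluation_config.pb.h"
#include "tensorflow/lite/tools/evaluation/stages/image_classification_stage.h"

// Hypothetical wrapper: configures the stage, runs one image, and reads back
// the aggregated Top-K metrics.
TfLiteStatus RunOneImage(const std::string& model_path,
                         const std::vector<std::string>& model_labels,
                         const std::string& image_path,
                         const std::string& ground_truth_label) {
  tflite::evaluation::EvaluationStageConfig eval_config;
  eval_config.set_name("image_classification");
  auto* classification_params = eval_config.mutable_specification()
                                    ->mutable_image_classification_params();
  classification_params->mutable_inference_params()->set_model_file_path(
      model_path);
  classification_params->mutable_topk_accuracy_eval_params()->set_k(10);

  tflite::evaluation::ImageClassificationStage eval(eval_config);
  eval.SetAllLabels(model_labels);
  TF_LITE_ENSURE_STATUS(eval.Init());
  eval.SetInputs(image_path, ground_truth_label);
  TF_LITE_ENSURE_STATUS(eval.Run());

  // Cumulative Top-K accuracies live in the returned metrics proto.
  auto accuracy_metrics = eval.LatestMetrics()
                              .process_metrics()
                              .image_classification_metrics()
                              .topk_accuracy_metrics();
  LOG(INFO) << "Top-1 so far: " << accuracy_metrics.topk_accuracies(0);
  return kTfLiteOk;
}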

View File

@@ -6,6 +6,7 @@ licenses(["notice"]) # Apache 2.0
load("//tensorflow/lite:special_rules.bzl", "tflite_portable_test_suite")
load("//tensorflow:tensorflow.bzl", "tf_cc_binary")
load("//tensorflow/lite:build_def.bzl", "tflite_copts")
common_copts = ["-Wall"]
@@ -96,4 +97,23 @@ cc_test(
],
)
cc_library(
name = "command_line_flags",
srcs = ["command_line_flags.cc"],
hdrs = ["command_line_flags.h"],
copts = tflite_copts(),
)
cc_test(
name = "command_line_flags_test",
srcs = ["command_line_flags_test.cc"],
copts = tflite_copts(),
visibility = ["//visibility:private"],
deps = [
":command_line_flags",
"//tensorflow/lite/testing:util",
"@com_google_googletest//:gtest",
],
)
tflite_portable_test_suite()
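The flag parser now lives in //tensorflow/lite/tools and is shared by the accuracy and benchmark tools. A minimal sketch of the API, following the pattern used in imagenet_accuracy_eval.cc below; the flag names and defaults here are illustrative.

#include <string>
#include <vector>

#include "tensorflow/core/platform/logging.h"
#include "tensorflow/lite/tools/command_line_flags.h"

int main(int argc, char** argv) {
  // Illustrative flags; values are overwritten by matching --name=value args.
  std::string output_file_path;
  int num_threads = 4;
  std::vector<tflite::Flag> flag_list = {
      tflite::Flag::CreateFlag("output_file_path", &output_file_path,
                               "Path to output file."),
      tflite::Flag::CreateFlag("num_threads", &num_threads,
                               "Number of threads."),
  };
  tflite::Flags::Parse(&argc, const_cast<const char**>(argv), flag_list);
  LOG(INFO) << "Writing results to " << output_file_path << " with "
            << num_threads << " threads.";
  return 0;
}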

View File

@@ -316,16 +316,10 @@ cc_library(
name = "csv_writer",
hdrs = ["csv_writer.h"],
copts = tflite_copts(),
deps = select(
{
"//tensorflow:android": [
"//tensorflow/core:android_tensorflow_lib",
],
"//conditions:default": [
"//tensorflow/core:lib",
],
},
),
deps = [
"//tensorflow/core:tflite_portable_logging",
"//tensorflow/lite/c:c_api_internal",
],
)
tflite_portable_test_suite()

View File

@@ -19,8 +19,8 @@ limitations under the License.
#include <fstream>
#include <vector>
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/lite/c/c_api_internal.h"
namespace tensorflow {
namespace metrics {
@@ -36,14 +36,17 @@ class CSVWriter {
public:
CSVWriter(const std::vector<string>& columns, std::ofstream* output_stream)
: num_columns_(columns.size()), output_stream_(output_stream) {
TF_CHECK_OK(WriteRow(columns, output_stream_));
if (WriteRow(columns, output_stream_) != kTfLiteOk) {
LOG(ERROR) << "Could not write column names to file";
}
}
template <typename T>
Status WriteRow(const std::vector<T>& values) {
TfLiteStatus WriteRow(const std::vector<T>& values) {
if (values.size() != num_columns_) {
return errors::InvalidArgument("Invalid size for row:", values.size(),
" expected: ", num_columns_);
LOG(ERROR) << "Invalid size for row:" << values.size()
<< " expected: " << num_columns_;
return kTfLiteError;
}
return WriteRow(values, output_stream_);
}
@@ -54,8 +57,8 @@ class CSVWriter {
private:
template <typename T>
static Status WriteRow(const std::vector<T>& values,
std::ofstream* output_stream) {
static TfLiteStatus WriteRow(const std::vector<T>& values,
std::ofstream* output_stream) {
bool first = true;
for (const auto& v : values) {
if (!first) {
@@ -67,9 +70,10 @@ class CSVWriter {
}
(*output_stream) << "\n";
if (!output_stream->good()) {
return errors::Internal("Writing to stream failed.");
LOG(ERROR) << "Writing to stream failed.";
return kTfLiteError;
}
return Status::OK();
return kTfLiteOk;
}
const size_t num_columns_;
std::ofstream* output_stream_;
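With WriteRow() now returning TfLiteStatus, callers compare against kTfLiteOk instead of wrapping calls in TF_CHECK_OK. A short, hypothetical usage sketch; the file path, column names, and values are illustrative.

#include <fstream>
#include <string>
#include <vector>

#include "tensorflow/lite/tools/accuracy/csv_writer.h"

void WriteAccuracies() {
  std::ofstream output_stream("/tmp/accuracies.csv", std::ios::out);
  // The header row is written by the constructor; failures are logged.
  tensorflow::metrics::CSVWriter writer({"Top 1", "Top 5"}, &output_stream);
  std::vector<double> row = {72.5, 91.0};
  if (writer.WriteRow(row) != kTfLiteOk) {
    // Handle the error; WriteRow no longer returns tensorflow::Status.
    return;
  }
  writer.Flush();
}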

View File

@@ -126,58 +126,34 @@ cc_library(
hdrs = ["imagenet_model_evaluator.h"],
copts = tflite_copts(),
deps = [
":imagenet_topk_eval",
":inception_preprocessing",
"//tensorflow/core:tflite_portable_logging",
"//tensorflow/lite/c:c_api_internal",
"//tensorflow/lite/tools:command_line_flags",
"//tensorflow/lite/tools/accuracy:android_required_build_flags",
"//tensorflow/lite/tools/accuracy:eval_pipeline",
"//tensorflow/lite/tools/accuracy:eval_pipeline_builder",
"//tensorflow/lite/tools/accuracy:file_reader_stage",
"//tensorflow/lite/tools/accuracy:run_tflite_model_stage",
"//tensorflow/lite/tools/accuracy:utils",
"//tensorflow/lite/tools/evaluation:utils",
"//tensorflow/lite/tools/evaluation/proto:evaluation_config_cc_proto",
"//tensorflow/lite/tools/evaluation/proto:evaluation_stages_cc_proto",
"//tensorflow/lite/tools/evaluation/stages:image_classification_stage",
"@com_google_absl//absl/memory",
"//tensorflow/cc:cc_ops",
"//tensorflow/cc:scope",
] + select(
{
"//tensorflow:android": [
"//tensorflow/core:android_tensorflow_lib",
"//tensorflow/core/kernels:android_whole_file_read_ops",
"//tensorflow/core/kernels:android_tensorflow_image_op",
],
"//conditions:default": [
"//tensorflow/core:tensorflow",
"//tensorflow/core:lib_internal",
"//tensorflow/core:framework_internal",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:core_cpu",
],
},
),
],
)
tf_cc_binary(
cc_binary(
name = "imagenet_accuracy_eval",
srcs = ["imagenet_accuracy_eval.cc"],
copts = tflite_copts(),
linkopts = common_linkopts,
deps = [
":imagenet_model_evaluator",
":imagenet_topk_eval",
"@com_google_absl//absl/memory",
"//tensorflow/core:tflite_portable_logging",
"//tensorflow/lite/c:c_api_internal",
"//tensorflow/lite/profiling:time",
"//tensorflow/lite/tools:command_line_flags",
"//tensorflow/lite/tools/accuracy:android_required_build_flags",
"//tensorflow/lite/tools/accuracy:csv_writer",
] + select(
{
"//tensorflow:android": [
"//tensorflow/core:android_tensorflow_lib",
],
"//conditions:default": [
"//tensorflow/core:lib",
"//tensorflow/core:framework_internal",
],
},
),
"//tensorflow/lite/tools/evaluation/proto:evaluation_stages_cc_proto",
"@com_google_absl//absl/memory",
],
)
tflite_portable_test_suite()

View File

@@ -28,7 +28,7 @@ The binary takes the following parameters:
`mobilenet_labels.txt` where each label is in the same order as the output
1001 dimension tensor.
* `output_path`: `string` \
* `output_file_path`: `string` \
This is the path to the output file. The output is a CSV file that has
top-10 accuracies in each row. Each line of the output file is the cumulative
accuracy after processing images in sorted order. So the first line is

View File

@@ -15,97 +15,125 @@ limitations under the License.
#include <iomanip>
#include <memory>
#include <mutex> // NOLINT(build/c++11)
#include <string>
#include "absl/memory/memory.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/lite/c/c_api_internal.h"
#include "tensorflow/lite/profiling/time.h"
#include "tensorflow/lite/tools/accuracy/csv_writer.h"
#include "tensorflow/lite/tools/accuracy/ilsvrc/imagenet_model_evaluator.h"
#include "tensorflow/lite/tools/accuracy/ilsvrc/imagenet_topk_eval.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/util/command_line_flags.h"
#include "tensorflow/lite/tools/command_line_flags.h"
#include "tensorflow/lite/tools/evaluation/proto/evaluation_stages.pb.h"
namespace tensorflow {
namespace metrics {
namespace {
std::vector<double> GetAccuracies(
    const ImagenetTopKAccuracy::AccuracyStats& accuracy_stats) {
  std::vector<double> results;
  results.reserve(accuracy_stats.number_of_images);
  if (accuracy_stats.number_of_images > 0) {
    for (int n : accuracy_stats.topk_counts) {
      double accuracy = 0;
      if (accuracy_stats.number_of_images > 0) {
        accuracy = (n * 100.0) / accuracy_stats.number_of_images;
      }
      results.push_back(accuracy);
    }
  }
  return results;
}
using ::tflite::evaluation::TopkAccuracyEvalMetrics;
constexpr char kNumThreadsFlag[] = "num_threads";
constexpr char kOutputFilePathFlag[] = "output_file_path";
// TODO(b/130823599): Move to tools/evaluation/stages/topk_accuracy_eval_stage.
// Computes the total number of images processed & aggregates Top-K accuracies
// into 'accuracies'.
void AggregateAccuraciesAndNumImages(
    int k,
    const std::unordered_map<uint64_t, TopkAccuracyEvalMetrics>&
        shard_id_accuracy_metrics_map,
    const std::unordered_map<uint64_t, int>& shard_id_done_image_count_map,
    std::vector<double>* accuracies, int* num_done_images) {
  // Total images done.
  *num_done_images = 0;
  for (auto iter = shard_id_done_image_count_map.begin();
       iter != shard_id_done_image_count_map.end(); ++iter) {
    *num_done_images += iter->second;
  }
  // Aggregated accuracies.
  for (int i = 0; i < k; ++i) {
    double correct_inferences = 0;
    double total_inferences = 0;
    for (auto iter = shard_id_done_image_count_map.begin();
         iter != shard_id_done_image_count_map.end(); ++iter) {
      const uint64_t shard_id = iter->first;
      const TopkAccuracyEvalMetrics& accuracy_metrics =
          shard_id_accuracy_metrics_map.at(shard_id);
      const int num_images = iter->second;
      correct_inferences += num_images * accuracy_metrics.topk_accuracies(i);
      total_inferences += num_images;
    }
    // Convert to percentage.
    accuracies->push_back(100.0 * correct_inferences / total_inferences);
  }
}
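A hypothetical illustration of the aggregation above: per-shard Top-K accuracies are weighted by the number of images each shard has completed. All numbers are made up, and the snippet assumes the usual generated proto setters for TopkAccuracyEvalMetrics.

// Hypothetical numbers purely for illustration.
void AggregationExample() {
  std::unordered_map<uint64_t, TopkAccuracyEvalMetrics> metrics_map;
  std::unordered_map<uint64_t, int> count_map;
  metrics_map[1].add_topk_accuracies(0.8f);  // Shard 1: 80% Top-1 so far.
  metrics_map[2].add_topk_accuracies(0.9f);  // Shard 2: 90% Top-1 so far.
  count_map[1] = 10;  // 10 images completed in shard 1.
  count_map[2] = 30;  // 30 images completed in shard 2.

  std::vector<double> accuracies;
  int num_done_images = 0;
  AggregateAccuraciesAndNumImages(/*k=*/1, metrics_map, count_map, &accuracies,
                                  &num_done_images);
  // num_done_images == 40
  // accuracies[0] == 100.0 * (10 * 0.8 + 30 * 0.9) / 40 == 87.5
}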
} // namespace
// Writes results to a CSV file.
// Writes results to a CSV file & logs progress to standard output every
// `kLogDelayUs` microseconds.
class ResultsWriter : public ImagenetModelEvaluator::Observer {
public:
explicit ResultsWriter(std::unique_ptr<CSVWriter> writer)
: writer_(std::move(writer)) {}
explicit ResultsWriter(int k, std::unique_ptr<CSVWriter> writer)
: k_(k), writer_(std::move(writer)) {}
void OnEvaluationStart(const std::unordered_map<uint64_t, int>&
shard_id_image_count_map) override {}
void OnSingleImageEvaluationComplete(
uint64_t shard_id, const ImagenetTopKAccuracy::AccuracyStats& stats,
const string& image) override;
private:
std::unique_ptr<CSVWriter> writer_ GUARDED_BY(mu_);
mutex mu_;
};
void ResultsWriter::OnSingleImageEvaluationComplete(
uint64_t shard_id, const ImagenetTopKAccuracy::AccuracyStats& stats,
const string& image) {
mutex_lock lock(mu_);
TF_CHECK_OK(writer_->WriteRow(GetAccuracies(stats)));
writer_->Flush();
}
// Logs results to standard output with `kLogDelayUs` microseconds.
class ResultsLogger : public ImagenetModelEvaluator::Observer {
public:
void OnEvaluationStart(const std::unordered_map<uint64_t, int>&
shard_id_image_count_map) override;
void OnSingleImageEvaluationComplete(
uint64_t shard_id, const ImagenetTopKAccuracy::AccuracyStats& stats,
const string& image) override;
void OnSingleImageEvaluationComplete(uint64_t shard_id,
const TopkAccuracyEvalMetrics& metrics,
const string& image) override;
private:
uint64_t last_logged_time_us_ GUARDED_BY(mu_) = 0;
int total_num_images_ GUARDED_BY(mu_);
// For writing to CSV.
int k_;
std::unordered_map<uint64_t, TopkAccuracyEvalMetrics>
shard_id_accuracy_metrics_map_;
std::unordered_map<uint64_t, int> shard_id_done_image_count_map_;
std::unique_ptr<CSVWriter> writer_;
// For logging to stdout.
uint64_t last_logged_time_us_ = 0;
int total_num_images_;
static constexpr int kLogDelayUs = 500 * 1000;
mutex mu_;
std::mutex mu_;
};
void ResultsLogger::OnEvaluationStart(
void ResultsWriter::OnEvaluationStart(
const std::unordered_map<uint64_t, int>& shard_id_image_count_map) {
int total_num_images = 0;
for (const auto& kv : shard_id_image_count_map) {
total_num_images += kv.second;
}
LOG(ERROR) << "Starting model evaluation: " << total_num_images;
mutex_lock lock(mu_);
std::lock_guard<std::mutex> lock(mu_);
total_num_images_ = total_num_images;
}
void ResultsLogger::OnSingleImageEvaluationComplete(
uint64_t shard_id, const ImagenetTopKAccuracy::AccuracyStats& stats,
void ResultsWriter::OnSingleImageEvaluationComplete(
uint64_t shard_id,
const tflite::evaluation::TopkAccuracyEvalMetrics& metrics,
const string& image) {
auto now_us = Env::Default()->NowMicros();
int num_evaluated = stats.number_of_images;
mutex_lock lock(mu_);
std::lock_guard<std::mutex> lock(mu_);
shard_id_done_image_count_map_[shard_id] += 1;
shard_id_accuracy_metrics_map_[shard_id] = metrics;
int num_evaluated;
std::vector<double> total_accuracies;
AggregateAccuraciesAndNumImages(k_, shard_id_accuracy_metrics_map_,
shard_id_done_image_count_map_,
&total_accuracies, &num_evaluated);
if (writer_->WriteRow(total_accuracies) != kTfLiteOk) {
LOG(ERROR) << "Could not write to file";
return;
}
writer_->Flush();
auto now_us = tflite::profiling::time::NowMicros();
if ((now_us - last_logged_time_us_) >= kLogDelayUs) {
last_logged_time_us_ = now_us;
double current_percent = num_evaluated * 100.0 / total_num_images_;
@@ -116,44 +144,52 @@ void ResultsLogger::OnSingleImageEvaluationComplete(
}
int Main(int argc, char* argv[]) {
// TODO(shashishekhar): Make this binary configurable and model
// agnostic.
string output_file_path;
int num_threads = 4;
std::vector<Flag> flag_list = {
Flag("output_file_path", &output_file_path, "Path to output file."),
Flag("num_threads", &num_threads, "Number of threads."),
std::vector<tflite::Flag> flag_list = {
tflite::Flag::CreateFlag(kNumThreadsFlag, &num_threads,
"Number of threads."),
tflite::Flag::CreateFlag(kOutputFilePathFlag, &output_file_path,
"Path to output file."),
};
Flags::Parse(&argc, argv, flag_list);
tflite::Flags::Parse(&argc, const_cast<const char**>(argv), flag_list);
std::unique_ptr<ImagenetModelEvaluator> evaluator;
CHECK(!output_file_path.empty()) << "Invalid output file path.";
if (output_file_path.empty()) {
LOG(ERROR) << "Invalid output file path.";
return 0;
}
CHECK(num_threads > 0) << "Invalid number of threads.";
if (num_threads <= 0) {
LOG(ERROR) << "Invalid number of threads.";
return 0;
}
TF_CHECK_OK(
ImagenetModelEvaluator::Create(argc, argv, num_threads, &evaluator));
if (ImagenetModelEvaluator::Create(argc, argv, num_threads, &evaluator) !=
kTfLiteOk)
return 0;
std::ofstream output_stream(output_file_path, std::ios::out);
CHECK(output_stream) << "Unable to open output file path: '"
<< output_file_path << "'";
if (!output_stream) {
LOG(ERROR) << "Unable to open output file path: '" << output_file_path
<< "'";
}
output_stream << std::setprecision(3) << std::fixed;
std::vector<string> columns;
columns.reserve(evaluator->params().num_ranks);
for (int i = 0; i < evaluator->params().num_ranks; i++) {
string column_name = "Top ";
tensorflow::strings::StrAppend(&column_name, i + 1);
std::string column_name = "Top ";
column_name = column_name + std::to_string(i + 1);
columns.push_back(column_name);
}
ResultsWriter results_writer(
evaluator->params().num_ranks,
absl::make_unique<CSVWriter>(columns, &output_stream));
ResultsLogger logger;
evaluator->AddObserver(&results_writer);
evaluator->AddObserver(&logger);
LOG(ERROR) << "Starting evaluation with: " << num_threads << " threads.";
TF_CHECK_OK(evaluator->EvaluateModel());
evaluator->EvaluateModel();
return 0;
}

View File

@@ -15,31 +15,34 @@ limitations under the License.
#include "tensorflow/lite/tools/accuracy/ilsvrc/imagenet_model_evaluator.h"
#include <dirent.h>
#include <fstream>
#include <iomanip>
#include <mutex> // NOLINT(build/c++11)
#include <string>
#include <thread> // NOLINT(build/c++11)
#include <vector>
#include "absl/memory/memory.h"
#include "tensorflow/cc/framework/scope.h"
#include "tensorflow/core/lib/core/blocking_counter.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/core/util/command_line_flags.h"
#include "tensorflow/lite/tools/accuracy/eval_pipeline.h"
#include "tensorflow/lite/tools/accuracy/eval_pipeline_builder.h"
#include "tensorflow/lite/tools/accuracy/file_reader_stage.h"
#include "tensorflow/lite/tools/accuracy/ilsvrc/imagenet_topk_eval.h"
#include "tensorflow/lite/tools/accuracy/ilsvrc/inception_preprocessing.h"
#include "tensorflow/lite/tools/accuracy/run_tflite_model_stage.h"
#include "tensorflow/lite/tools/accuracy/utils.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/lite/c/c_api_internal.h"
#include "tensorflow/lite/tools/command_line_flags.h"
#include "tensorflow/lite/tools/evaluation/proto/evaluation_config.pb.h"
#include "tensorflow/lite/tools/evaluation/proto/evaluation_stages.pb.h"
#include "tensorflow/lite/tools/evaluation/stages/image_classification_stage.h"
#include "tensorflow/lite/tools/evaluation/utils.h"
namespace {
using tensorflow::string;
string StripTrailingSlashes(const string& path) {
constexpr char kNumImagesFlag[] = "num_images";
constexpr char kModelOutputLabelsFlag[] = "model_output_labels";
constexpr char kGroundTruthImagesPathFlag[] = "ground_truth_images_path";
constexpr char kGroundTruthLabelsFlag[] = "ground_truth_labels";
constexpr char kBlacklistFilePathFlag[] = "blacklist_file_path";
constexpr char kModelFileFlag[] = "model_file";
std::string StripTrailingSlashes(const std::string& path) {
int end = path.size();
while (end > 0 && path[end - 1] == '/') {
end--;
@@ -47,12 +50,6 @@ string StripTrailingSlashes(const string& path) {
return path.substr(0, end);
}
tensorflow::Tensor CreateStringTensor(const string& value) {
tensorflow::Tensor tensor(tensorflow::DT_STRING, tensorflow::TensorShape({}));
tensor.scalar<string>()() = value;
return tensor;
}
template <typename T>
std::vector<T> GetFirstN(const std::vector<T>& v, int n) {
if (n >= v.size()) return v;
@@ -62,7 +59,9 @@ std::vector<T> GetFirstN(const std::vector<T>& v, int n) {
template <typename T>
std::vector<std::vector<T>> Split(const std::vector<T>& v, int n) {
CHECK_GT(n, 0);
if (n <= 0) {
return std::vector<std::vector<T>>();
}
std::vector<std::vector<T>> vecs(n);
int input_index = 0;
int vec_index = 0;
@@ -71,7 +70,6 @@ std::vector<std::vector<T>> Split(const std::vector<T>& v, int n) {
vec_index = (vec_index + 1) % n;
input_index++;
}
CHECK_EQ(vecs.size(), n);
return vecs;
}
@@ -90,149 +88,124 @@ class CompositeObserver : public ImagenetModelEvaluator::Observer {
void OnEvaluationStart(const std::unordered_map<uint64_t, int>&
shard_id_image_count_map) override {
mutex_lock lock(mu_);
std::lock_guard<std::mutex> lock(mu_);
for (auto observer : observers_) {
observer->OnEvaluationStart(shard_id_image_count_map);
}
}
void OnSingleImageEvaluationComplete(
uint64_t shard_id, const ImagenetTopKAccuracy::AccuracyStats& stats,
const string& image) override {
mutex_lock lock(mu_);
uint64_t shard_id,
const tflite::evaluation::TopkAccuracyEvalMetrics& metrics,
const std::string& image) override {
std::lock_guard<std::mutex> lock(mu_);
for (auto observer : observers_) {
observer->OnSingleImageEvaluationComplete(shard_id, stats, image);
observer->OnSingleImageEvaluationComplete(shard_id, metrics, image);
}
}
private:
const std::vector<ImagenetModelEvaluator::Observer*>& observers_
GUARDED_BY(mu_);
mutex mu_;
const std::vector<ImagenetModelEvaluator::Observer*>& observers_;
std::mutex mu_;
};
/*static*/ Status ImagenetModelEvaluator::Create(
/*static*/ TfLiteStatus ImagenetModelEvaluator::Create(
int argc, char* argv[], int num_threads,
std::unique_ptr<ImagenetModelEvaluator>* model_evaluator) {
Params params;
const std::vector<Flag> flag_list = {
Flag("model_output_labels", &params.model_output_labels_path,
"Path to labels that correspond to output of model."
" E.g. in case of mobilenet, this is the path to label "
"file where each label is in the same order as the output"
" of the model."),
Flag("ground_truth_images_path", &params.ground_truth_images_path,
"Path to ground truth images."),
Flag("ground_truth_labels", &params.ground_truth_labels_path,
"Path to ground truth labels."),
Flag("num_images", &params.number_of_images,
"Number of examples to evaluate, pass 0 for all "
"examples. Default: 100"),
Flag("blacklist_file_path", &params.blacklist_file_path,
"Path to blacklist file (optional)."
"Path to blacklist file where each line is a single integer that is "
"equal to number of blacklisted image."),
Flag("model_file", &params.model_file_path,
"Path to test tflite model file."),
};
const bool parse_result = Flags::Parse(&argc, argv, flag_list);
if (!parse_result)
return errors::InvalidArgument("Invalid command line flags");
::tensorflow::port::InitMain(argv[0], &argc, &argv);
  TF_RETURN_WITH_CONTEXT_IF_ERROR(
      Env::Default()->IsDirectory(params.ground_truth_images_path),
      "Invalid ground truth data path.");
  TF_RETURN_WITH_CONTEXT_IF_ERROR(
      Env::Default()->FileExists(params.ground_truth_labels_path),
      "Invalid ground truth labels path.");
  TF_RETURN_WITH_CONTEXT_IF_ERROR(
      Env::Default()->FileExists(params.model_output_labels_path),
      "Invalid model output labels path.");
  if (!params.blacklist_file_path.empty()) {
    TF_RETURN_WITH_CONTEXT_IF_ERROR(
        Env::Default()->FileExists(params.blacklist_file_path),
        "Invalid blacklist path.");
  }
  params.number_of_images = 100;
  std::vector<tflite::Flag> flag_list = {
      tflite::Flag::CreateFlag(kNumImagesFlag, &params.number_of_images,
                               "Number of examples to evaluate, pass 0 for all "
                               "examples. Default: 100"),
      tflite::Flag::CreateFlag(
          kModelOutputLabelsFlag, &params.model_output_labels_path,
          "Path to labels that correspond to output of model."
          " E.g. in case of mobilenet, this is the path to label "
          "file where each label is in the same order as the output"
          " of the model."),
      tflite::Flag::CreateFlag(
          kGroundTruthImagesPathFlag, &params.ground_truth_images_path,
          "Path to ground truth images. These will be evaluated in "
          "alphabetical order of filename"),
      tflite::Flag::CreateFlag(
          kGroundTruthLabelsFlag, &params.ground_truth_labels_path,
          "Path to ground truth labels, corresponding to alphabetical ordering "
          "of ground truth images."),
      tflite::Flag::CreateFlag(
          kBlacklistFilePathFlag, &params.blacklist_file_path,
          "Path to blacklist file (optional) where each line is a single "
          "integer equal to the index of a blacklisted image."),
      tflite::Flag::CreateFlag(kModelFileFlag, &params.model_file_path,
                               "Path to test tflite model file.")};
tflite::Flags::Parse(&argc, const_cast<const char**>(argv), flag_list);
if (params.number_of_images < 0) {
return errors::InvalidArgument("Invalid: num_examples");
LOG(ERROR) << "Invalid: num_examples";
return kTfLiteError;
}
utils::ModelInfo model_info;
TF_RETURN_WITH_CONTEXT_IF_ERROR(
utils::GetTFliteModelInfo(params.model_file_path, &model_info),
"Invalid TFLite model.");
*model_evaluator = absl::make_unique<ImagenetModelEvaluator>(
model_info, params, num_threads);
return Status::OK();
*model_evaluator =
absl::make_unique<ImagenetModelEvaluator>(params, num_threads);
return kTfLiteOk;
}
struct ImageLabel {
string image;
string label;
std::string image;
std::string label;
};
Status EvaluateModelForShard(const uint64_t shard_id,
const std::vector<ImageLabel>& image_labels,
const std::vector<string>& model_labels,
const utils::ModelInfo& model_info,
const ImagenetModelEvaluator::Params& params,
ImagenetModelEvaluator::Observer* observer,
ImagenetTopKAccuracy* eval) {
const TensorShape& input_shape = model_info.input_shapes[0];
const int image_height = input_shape.dim_size(1);
const int image_width = input_shape.dim_size(2);
TfLiteStatus EvaluateModelForShard(const uint64_t shard_id,
const std::vector<ImageLabel>& image_labels,
const std::vector<std::string>& model_labels,
const ImagenetModelEvaluator::Params& params,
ImagenetModelEvaluator::Observer* observer,
int num_ranks) {
tflite::evaluation::EvaluationStageConfig eval_config;
eval_config.set_name("image_classification");
auto* classification_params = eval_config.mutable_specification()
->mutable_image_classification_params();
auto* inference_params = classification_params->mutable_inference_params();
inference_params->set_model_file_path(params.model_file_path);
classification_params->mutable_topk_accuracy_eval_params()->set_k(num_ranks);
RunTFLiteModelStage::Params tfl_model_params;
tfl_model_params.model_file_path = params.model_file_path;
tfl_model_params.input_type = {model_info.input_types[0]};
tfl_model_params.output_type = {model_info.input_types[0]};
Scope root = Scope::NewRootScope();
FileReaderStage reader;
InceptionPreprocessingStage inc(image_height, image_width,
model_info.input_types[0]);
RunTFLiteModelStage tfl_model_stage(tfl_model_params);
EvalPipelineBuilder builder;
std::unique_ptr<EvalPipeline> eval_pipeline;
auto build_status = builder.WithInputStage(&reader)
.WithPreprocessingStage(&inc)
.WithRunModelStage(&tfl_model_stage)
.WithAccuracyEval(eval)
.WithInput("input_file", DT_STRING)
.Build(root, &eval_pipeline);
TF_RETURN_WITH_CONTEXT_IF_ERROR(build_status,
"Failure while building eval pipeline.");
std::unique_ptr<Session> session(NewSession(SessionOptions()));
TF_RETURN_IF_ERROR(eval_pipeline->AttachSession(std::move(session)));
tflite::evaluation::ImageClassificationStage eval(eval_config);
eval.SetAllLabels(model_labels);
TF_LITE_ENSURE_STATUS(eval.Init());
for (const auto& image_label : image_labels) {
TF_CHECK_OK(eval_pipeline->Run(CreateStringTensor(image_label.image),
CreateStringTensor(image_label.label)));
eval.SetInputs(image_label.image, image_label.label);
TF_LITE_ENSURE_STATUS(eval.Run());
observer->OnSingleImageEvaluationComplete(
shard_id, eval->GetTopKAccuracySoFar(), image_label.image);
shard_id,
eval.LatestMetrics()
.process_metrics()
.image_classification_metrics()
.topk_accuracy_metrics(),
image_label.image);
}
return Status::OK();
return kTfLiteOk;
}
Status FilterBlackListedImages(const string& blacklist_file_path,
std::vector<ImageLabel>* image_labels) {
// TODO(b/130823599): Move to tools/evaluation/utils.
TfLiteStatus FilterBlackListedImages(const std::string& blacklist_file_path,
std::vector<ImageLabel>* image_labels) {
if (!blacklist_file_path.empty()) {
std::vector<string> lines;
TF_RETURN_IF_ERROR(utils::ReadFileLines(blacklist_file_path, &lines));
std::vector<std::string> lines;
if (!tflite::evaluation::ReadFileLines(blacklist_file_path, &lines)) {
LOG(ERROR) << "Could not read: " << blacklist_file_path;
return kTfLiteError;
}
std::vector<int> blacklist_ids;
blacklist_ids.reserve(lines.size());
// Populate blacklist_ids with indices of images.
std::transform(lines.begin(), lines.end(),
std::back_inserter(blacklist_ids),
[](const string& val) { return std::stoi(val) - 1; });
[](const std::string& val) { return std::stoi(val) - 1; });
std::vector<ImageLabel> filtered_images;
std::sort(blacklist_ids.begin(), blacklist_ids.end());
@@ -251,38 +224,51 @@ Status FilterBlackListedImages(const string& blacklist_file_path,
}
if (filtered_images.size() != size_post_filtering) {
return errors::Internal("Invalid number of filtered images");
LOG(ERROR) << "Invalid number of filtered images";
return kTfLiteError;
}
*image_labels = filtered_images;
}
return Status::OK();
return kTfLiteOk;
}
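For illustration, a hypothetical call showing the blacklist semantics: each line of the blacklist file is a 1-based index into the sorted image list, and the helper subtracts one (std::stoi(val) - 1) before filtering. The path, file contents, and counts are made up.

void BlacklistExample() {
  // Twelve images in sorted order; /tmp/blacklist.txt contains "3\n10\n".
  std::vector<ImageLabel> image_labels;
  for (int i = 0; i < 12; ++i) {
    image_labels.push_back({"img_" + std::to_string(i) + ".jpg",
                            "label_" + std::to_string(i)});
  }
  if (FilterBlackListedImages("/tmp/blacklist.txt", &image_labels) !=
      kTfLiteOk) {
    LOG(ERROR) << "Could not filter by blacklist";
    return;
  }
  // image_labels now holds 10 entries; the 3rd and 10th of the original
  // twelve were dropped.
}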
Status ImagenetModelEvaluator::EvaluateModel() const {
if (model_info_.input_shapes.size() != 1) {
return errors::InvalidArgument("Invalid input shape");
// TODO(b/130823599): Move to tools/evaluation/utils.
TfLiteStatus GetSortedFileNames(const std::string dir_path,
std::vector<std::string>* result) {
DIR* dir;
struct dirent* ent;
if (result == nullptr) {
LOG(ERROR) << "result cannot be nullptr";
return kTfLiteError;
}
const TensorShape& input_shape = model_info_.input_shapes[0];
// Input should be of the shape {1, height, width, 3}
if (input_shape.dims() != 4 || input_shape.dim_size(3) != 3) {
return errors::InvalidArgument("Invalid input shape for the model.");
if ((dir = opendir(dir_path.c_str())) != nullptr) {
while ((ent = readdir(dir)) != nullptr) {
std::string filename(std::string(ent->d_name));
if (filename.size() <= 2) continue;
result->emplace_back(dir_path + "/" + filename);
}
closedir(dir);
} else {
LOG(ERROR) << "Could not open dir: " << dir_path;
return kTfLiteError;
}
std::sort(result->begin(), result->end());
return kTfLiteOk;
}
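A hypothetical call to the new directory helper: it returns full paths in alphabetical order and skips any entry whose name is two characters or shorter, which drops "." and "..". The directory path and wrapper function are illustrative.

TfLiteStatus ListImages(std::vector<std::string>* image_files) {
  // Entries come back sorted alphabetically, as full paths.
  return GetSortedFileNames("/data/ilsvrc_images", image_files);
}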
string data_path =
TfLiteStatus ImagenetModelEvaluator::EvaluateModel() const {
const std::string data_path =
StripTrailingSlashes(params_.ground_truth_images_path) + "/";
const string imagenet_file_pattern = data_path + kImagenetFilePattern;
std::vector<string> image_files;
TF_CHECK_OK(
Env::Default()->GetMatchingPaths(imagenet_file_pattern, &image_files));
std::vector<std::string> image_files;
TF_LITE_ENSURE_STATUS(GetSortedFileNames(data_path, &image_files));
std::vector<string> ground_truth_image_labels;
TF_CHECK_OK(utils::ReadFileLines(params_.ground_truth_labels_path,
&ground_truth_image_labels));
CHECK_EQ(image_files.size(), ground_truth_image_labels.size());
// Process files in filename sorted order.
std::sort(image_files.begin(), image_files.end());
if (!tflite::evaluation::ReadFileLines(params_.ground_truth_labels_path,
&ground_truth_image_labels))
return kTfLiteError;
if (image_files.size() != ground_truth_image_labels.size()) {
LOG(ERROR) << "Images and ground truth labels don't match";
return kTfLiteError;
}
std::vector<ImageLabel> image_labels;
image_labels.reserve(image_files.size());
@@ -291,56 +277,55 @@ Status ImagenetModelEvaluator::EvaluateModel() const {
}
// Filter any blacklisted images.
TF_CHECK_OK(
FilterBlackListedImages(params_.blacklist_file_path, &image_labels));
if (FilterBlackListedImages(params_.blacklist_file_path, &image_labels) !=
kTfLiteOk) {
LOG(ERROR) << "Could not filter by blacklist";
return kTfLiteError;
}
if (params_.number_of_images > 0) {
image_labels = GetFirstN(image_labels, params_.number_of_images);
}
std::vector<string> model_labels;
TF_RETURN_IF_ERROR(
utils::ReadFileLines(params_.model_output_labels_path, &model_labels));
if (model_labels.size() != 1001) {
return errors::InvalidArgument("Invalid number of labels: ",
model_labels.size());
if (!tflite::evaluation::ReadFileLines(params_.model_output_labels_path,
&model_labels)) {
LOG(ERROR) << "Could not read: " << params_.model_output_labels_path;
return kTfLiteError;
}
if (model_labels.size() != 1001) {
LOG(ERROR) << "Invalid number of labels: " << model_labels.size();
return kTfLiteError;
}
ImagenetTopKAccuracy eval(model_labels, params_.num_ranks);
auto img_labels = Split(image_labels, num_threads_);
BlockingCounter counter(num_threads_);
CompositeObserver observer(observers_);
::tensorflow::thread::ThreadPool pool(Env::Default(), "evaluation_pool",
num_threads_);
std::vector<std::thread> thread_pool;
bool all_okay = true;
std::unordered_map<uint64_t, int> shard_id_image_count_map;
std::vector<std::function<void()>> thread_funcs;
thread_funcs.reserve(num_threads_);
thread_pool.reserve(num_threads_);
for (int i = 0; i < num_threads_; i++) {
const auto& image_label = img_labels[i];
const uint64_t shard_id = i + 1;
shard_id_image_count_map[shard_id] = image_label.size();
auto func = [shard_id, &image_label, &model_labels, this, &observer, &eval,
&counter]() {
TF_CHECK_OK(EvaluateModelForShard(shard_id, image_label, model_labels,
model_info_, params_, &observer,
&eval));
counter.DecrementCount();
auto func = [shard_id, &image_label, &model_labels, this, &observer,
&all_okay]() {
if (EvaluateModelForShard(shard_id, image_label, model_labels, params_,
&observer, params_.num_ranks) != kTfLiteOk) {
all_okay = all_okay && false;
}
};
thread_funcs.push_back(func);
thread_pool.push_back(std::thread(func));
}
observer.OnEvaluationStart(shard_id_image_count_map);
for (const auto& func : thread_funcs) {
pool.Schedule(func);
for (auto& thread : thread_pool) {
thread.join();
}
counter.Wait();
return Status::OK();
return kTfLiteOk;
}
} // namespace metrics

View File

@@ -13,15 +13,14 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_TOOLS_ACCURACY_IMAGENET_MODEL_EVALUATOR_H_
#define TENSORFLOW_LITE_TOOLS_ACCURACY_IMAGENET_MODEL_EVALUATOR_H_
#ifndef TENSORFLOW_LITE_TOOLS_ACCURACY_ILSVRC_IMAGENET_MODEL_EVALUATOR_H_
#define TENSORFLOW_LITE_TOOLS_ACCURACY_ILSVRC_IMAGENET_MODEL_EVALUATOR_H_
#include <string>
#include <vector>
#include "tensorflow/lite/tools/accuracy/ilsvrc/imagenet_topk_eval.h"
#include "tensorflow/lite/tools/accuracy/utils.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/lite/c/c_api_internal.h"
#include "tensorflow/lite/tools/evaluation/proto/evaluation_stages.pb.h"
namespace tensorflow {
namespace metrics {
@@ -42,26 +41,26 @@ class ImagenetModelEvaluator {
public:
struct Params {
// Path to ground truth images.
string ground_truth_images_path;
std::string ground_truth_images_path;
// Path to labels file for ground truth image.
// This file should be generated with the scripts.
string ground_truth_labels_path;
std::string ground_truth_labels_path;
// These are the word labels generated by the model. The category
// indices of output probabilities generated by the model may be different
// from the indices in the imagenet dataset.
string model_output_labels_path;
std::string model_output_labels_path;
// Path to the model file.
string model_file_path;
std::string model_file_path;
// Path to the blacklist file. 1762 images were blacklisted from the
// original ILSVRC dataset. This blacklist file is present in the
// ILSVRC2014 devkit. Please refer to readme.txt of the ILSVRC2014
// devkit for details.
// This file is a list of image indices in a sorted order.
string blacklist_file_path;
std::string blacklist_file_path;
// The maximum number of images to calculate accuracy.
// 0 means all images, a positive number means only the specified
@@ -90,19 +89,20 @@ class ImagenetModelEvaluator {
// Called when evaluation was complete for `image`.
virtual void OnSingleImageEvaluationComplete(
uint64_t shard_id, const ImagenetTopKAccuracy::AccuracyStats& stats,
uint64_t shard_id,
const tflite::evaluation::TopkAccuracyEvalMetrics& metrics,
const string& image) = 0;
virtual ~Observer() = default;
};
ImagenetModelEvaluator(const utils::ModelInfo& model_info,
const Params& params, const int num_threads)
: model_info_(model_info), params_(params), num_threads_(num_threads) {}
ImagenetModelEvaluator(const Params& params, const int num_threads)
: params_(params), num_threads_(num_threads) {}
// Factory method to create the evaluator by parsing command line arguments.
static Status Create(int argc, char* argv[], int num_threads,
std::unique_ptr<ImagenetModelEvaluator>* evaluator);
static TfLiteStatus Create(
int argc, char* argv[], int num_threads,
std::unique_ptr<ImagenetModelEvaluator>* evaluator);
// Adds an observer that can observe evaluation events.
void AddObserver(Observer* observer) { observers_.push_back(observer); }
@@ -110,10 +110,9 @@ class ImagenetModelEvaluator {
const Params& params() const { return params_; }
// Evaluates the provided model over the dataset.
Status EvaluateModel() const;
TfLiteStatus EvaluateModel() const;
private:
const utils::ModelInfo model_info_;
const Params params_;
const int num_threads_;
std::vector<Observer*> observers_;
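Putting the refactored header together, a hypothetical sketch of driving the evaluator with a custom observer. The Create/AddObserver/EvaluateModel flow mirrors imagenet_accuracy_eval.cc; the observer class and driver function are made up.

#include <cstdint>
#include <memory>
#include <string>
#include <unordered_map>

#include "tensorflow/core/platform/logging.h"
#include "tensorflow/lite/tools/accuracy/ilsvrc/imagenet_model_evaluator.h"

// Hypothetical observer that prints the running Top-1 accuracy per shard.
class Top1Printer
    : public tensorflow::metrics::ImagenetModelEvaluator::Observer {
 public:
  void OnEvaluationStart(const std::unordered_map<uint64_t, int>&
                             shard_id_image_count_map) override {}

  void OnSingleImageEvaluationComplete(
      uint64_t shard_id,
      const tflite::evaluation::TopkAccuracyEvalMetrics& metrics,
      const std::string& image) override {
    LOG(INFO) << image << " done; shard " << shard_id
              << " Top-1 so far: " << metrics.topk_accuracies(0);
  }
};

// Hypothetical driver mirroring imagenet_accuracy_eval.cc.
int RunEvaluation(int argc, char* argv[]) {
  std::unique_ptr<tensorflow::metrics::ImagenetModelEvaluator> evaluator;
  if (tensorflow::metrics::ImagenetModelEvaluator::Create(
          argc, argv, /*num_threads=*/4, &evaluator) != kTfLiteOk) {
    return 1;
  }
  Top1Printer printer;
  evaluator->AddObserver(&printer);
  return evaluator->EvaluateModel() == kTfLiteOk ? 0 : 1;
}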

View File

@@ -6,8 +6,7 @@ licenses(["notice"]) # Apache 2.0
load("//tensorflow:tensorflow.bzl", "tf_cc_binary")
load("//tensorflow/lite:special_rules.bzl", "tflite_portable_test_suite")
load("//tensorflow/lite:build_def.bzl", "tflite_copts")
load("//tensorflow/lite:build_def.bzl", "tflite_linkopts")
load("//tensorflow/lite:build_def.bzl", "tflite_copts", "tflite_linkopts")
common_copts = ["-Wall"] + tflite_copts()
@@ -70,31 +69,9 @@ cc_test(
],
deps = [
":benchmark_tflite_model_lib",
":command_line_flags",
"//tensorflow/lite:framework",
"//tensorflow/lite/testing:util",
"@com_google_googletest//:gtest",
],
)
cc_library(
name = "command_line_flags",
srcs = ["command_line_flags.cc"],
hdrs = ["command_line_flags.h"],
copts = common_copts,
)
cc_test(
name = "command_line_flags_test",
srcs = ["command_line_flags_test.cc"],
copts = common_copts,
tags = [
"tflite_not_portable_ios", # TODO(b/117786830)
],
visibility = ["//visibility:private"],
deps = [
":command_line_flags",
"//tensorflow/lite/testing:util",
"//tensorflow/lite/tools:command_line_flags",
"@com_google_googletest//:gtest",
],
)
@@ -143,11 +120,11 @@ cc_library(
copts = common_copts,
deps = [
":benchmark_params",
":command_line_flags",
":logging",
"//tensorflow/core:stats_calculator_portable",
"//tensorflow/lite:framework",
"//tensorflow/lite/profiling:time",
"//tensorflow/lite/tools:command_line_flags",
],
)

View File

@@ -23,9 +23,9 @@ limitations under the License.
#include <unordered_set>
#include <vector>
#include "tensorflow/lite/tools/benchmark/benchmark_params.h"
#include "tensorflow/lite/tools/benchmark/command_line_flags.h"
#include "tensorflow/core/util/stats_calculator.h"
#include "tensorflow/lite/tools/benchmark/benchmark_params.h"
#include "tensorflow/lite/tools/command_line_flags.h"
namespace tflite {
namespace benchmark {

View File

@@ -21,7 +21,7 @@ limitations under the License.
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/testing/util.h"
#include "tensorflow/lite/tools/benchmark/benchmark_tflite_model.h"
#include "tensorflow/lite/tools/benchmark/command_line_flags.h"
#include "tensorflow/lite/tools/command_line_flags.h"
namespace {
const std::string* g_model_path = nullptr;

View File

@@ -10,7 +10,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/tools/benchmark/command_line_flags.h"
#include "tensorflow/lite/tools/command_line_flags.h"
#include <cstring>
#include <sstream>

View File

@@ -13,7 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/tools/benchmark/command_line_flags.h"
#include "tensorflow/lite/tools/command_line_flags.h"
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "tensorflow/lite/testing/util.h"