Refactors ImagenetModelEvaluator to use ImageClassificationStage.
PiperOrigin-RevId: 244374565
This commit is contained in:
parent
c767cbf14c
commit
37b20d191a
@ -6,6 +6,7 @@ licenses(["notice"]) # Apache 2.0
|
||||
|
||||
load("//tensorflow/lite:special_rules.bzl", "tflite_portable_test_suite")
|
||||
load("//tensorflow:tensorflow.bzl", "tf_cc_binary")
|
||||
load("//tensorflow/lite:build_def.bzl", "tflite_copts")
|
||||
|
||||
common_copts = ["-Wall"]
|
||||
|
||||
@ -96,4 +97,23 @@ cc_test(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "command_line_flags",
|
||||
srcs = ["command_line_flags.cc"],
|
||||
hdrs = ["command_line_flags.h"],
|
||||
copts = tflite_copts(),
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "command_line_flags_test",
|
||||
srcs = ["command_line_flags_test.cc"],
|
||||
copts = tflite_copts(),
|
||||
visibility = ["//visibility:private"],
|
||||
deps = [
|
||||
":command_line_flags",
|
||||
"//tensorflow/lite/testing:util",
|
||||
"@com_google_googletest//:gtest",
|
||||
],
|
||||
)
|
||||
|
||||
tflite_portable_test_suite()
|
||||
|
@ -316,16 +316,10 @@ cc_library(
|
||||
name = "csv_writer",
|
||||
hdrs = ["csv_writer.h"],
|
||||
copts = tflite_copts(),
|
||||
deps = select(
|
||||
{
|
||||
"//tensorflow:android": [
|
||||
"//tensorflow/core:android_tensorflow_lib",
|
||||
],
|
||||
"//conditions:default": [
|
||||
"//tensorflow/core:lib",
|
||||
],
|
||||
},
|
||||
),
|
||||
deps = [
|
||||
"//tensorflow/core:tflite_portable_logging",
|
||||
"//tensorflow/lite/c:c_api_internal",
|
||||
],
|
||||
)
|
||||
|
||||
tflite_portable_test_suite()
|
||||
|
@ -19,8 +19,8 @@ limitations under the License.
|
||||
#include <fstream>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/lite/c/c_api_internal.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace metrics {
|
||||
@ -36,14 +36,17 @@ class CSVWriter {
|
||||
public:
|
||||
CSVWriter(const std::vector<string>& columns, std::ofstream* output_stream)
|
||||
: num_columns_(columns.size()), output_stream_(output_stream) {
|
||||
TF_CHECK_OK(WriteRow(columns, output_stream_));
|
||||
if (WriteRow(columns, output_stream_) != kTfLiteOk) {
|
||||
LOG(ERROR) << "Could not write column names to file";
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Status WriteRow(const std::vector<T>& values) {
|
||||
TfLiteStatus WriteRow(const std::vector<T>& values) {
|
||||
if (values.size() != num_columns_) {
|
||||
return errors::InvalidArgument("Invalid size for row:", values.size(),
|
||||
" expected: ", num_columns_);
|
||||
LOG(ERROR) << "Invalid size for row:" << values.size()
|
||||
<< " expected: " << num_columns_;
|
||||
return kTfLiteError;
|
||||
}
|
||||
return WriteRow(values, output_stream_);
|
||||
}
|
||||
@ -54,8 +57,8 @@ class CSVWriter {
|
||||
|
||||
private:
|
||||
template <typename T>
|
||||
static Status WriteRow(const std::vector<T>& values,
|
||||
std::ofstream* output_stream) {
|
||||
static TfLiteStatus WriteRow(const std::vector<T>& values,
|
||||
std::ofstream* output_stream) {
|
||||
bool first = true;
|
||||
for (const auto& v : values) {
|
||||
if (!first) {
|
||||
@ -67,9 +70,10 @@ class CSVWriter {
|
||||
}
|
||||
(*output_stream) << "\n";
|
||||
if (!output_stream->good()) {
|
||||
return errors::Internal("Writing to stream failed.");
|
||||
LOG(ERROR) << "Writing to stream failed.";
|
||||
return kTfLiteError;
|
||||
}
|
||||
return Status::OK();
|
||||
return kTfLiteOk;
|
||||
}
|
||||
const size_t num_columns_;
|
||||
std::ofstream* output_stream_;
|
||||
|
@ -126,58 +126,34 @@ cc_library(
|
||||
hdrs = ["imagenet_model_evaluator.h"],
|
||||
copts = tflite_copts(),
|
||||
deps = [
|
||||
":imagenet_topk_eval",
|
||||
":inception_preprocessing",
|
||||
"//tensorflow/core:tflite_portable_logging",
|
||||
"//tensorflow/lite/c:c_api_internal",
|
||||
"//tensorflow/lite/tools:command_line_flags",
|
||||
"//tensorflow/lite/tools/accuracy:android_required_build_flags",
|
||||
"//tensorflow/lite/tools/accuracy:eval_pipeline",
|
||||
"//tensorflow/lite/tools/accuracy:eval_pipeline_builder",
|
||||
"//tensorflow/lite/tools/accuracy:file_reader_stage",
|
||||
"//tensorflow/lite/tools/accuracy:run_tflite_model_stage",
|
||||
"//tensorflow/lite/tools/accuracy:utils",
|
||||
"//tensorflow/lite/tools/evaluation:utils",
|
||||
"//tensorflow/lite/tools/evaluation/proto:evaluation_config_cc_proto",
|
||||
"//tensorflow/lite/tools/evaluation/proto:evaluation_stages_cc_proto",
|
||||
"//tensorflow/lite/tools/evaluation/stages:image_classification_stage",
|
||||
"@com_google_absl//absl/memory",
|
||||
"//tensorflow/cc:cc_ops",
|
||||
"//tensorflow/cc:scope",
|
||||
] + select(
|
||||
{
|
||||
"//tensorflow:android": [
|
||||
"//tensorflow/core:android_tensorflow_lib",
|
||||
"//tensorflow/core/kernels:android_whole_file_read_ops",
|
||||
"//tensorflow/core/kernels:android_tensorflow_image_op",
|
||||
],
|
||||
"//conditions:default": [
|
||||
"//tensorflow/core:tensorflow",
|
||||
"//tensorflow/core:lib_internal",
|
||||
"//tensorflow/core:framework_internal",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:core_cpu",
|
||||
],
|
||||
},
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_binary(
|
||||
cc_binary(
|
||||
name = "imagenet_accuracy_eval",
|
||||
srcs = ["imagenet_accuracy_eval.cc"],
|
||||
copts = tflite_copts(),
|
||||
linkopts = common_linkopts,
|
||||
deps = [
|
||||
":imagenet_model_evaluator",
|
||||
":imagenet_topk_eval",
|
||||
"@com_google_absl//absl/memory",
|
||||
"//tensorflow/core:tflite_portable_logging",
|
||||
"//tensorflow/lite/c:c_api_internal",
|
||||
"//tensorflow/lite/profiling:time",
|
||||
"//tensorflow/lite/tools:command_line_flags",
|
||||
"//tensorflow/lite/tools/accuracy:android_required_build_flags",
|
||||
"//tensorflow/lite/tools/accuracy:csv_writer",
|
||||
] + select(
|
||||
{
|
||||
"//tensorflow:android": [
|
||||
"//tensorflow/core:android_tensorflow_lib",
|
||||
],
|
||||
"//conditions:default": [
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:framework_internal",
|
||||
],
|
||||
},
|
||||
),
|
||||
"//tensorflow/lite/tools/evaluation/proto:evaluation_stages_cc_proto",
|
||||
"@com_google_absl//absl/memory",
|
||||
],
|
||||
)
|
||||
|
||||
tflite_portable_test_suite()
|
||||
|
@ -28,7 +28,7 @@ The binary takes the following parameters:
|
||||
`mobilenet_labels.txt` where each label is in the same order as the output
|
||||
1001 dimension tensor.
|
||||
|
||||
* `output_path`: `string` \
|
||||
* `output_file_path`: `string` \
|
||||
This is the path to the output file. The output is a CSV file that has
|
||||
top-10 accuracies in each row. Each line of output file is the cumulative
|
||||
accuracy after processing images in a sorted order. So first line is
|
||||
|
@ -15,97 +15,125 @@ limitations under the License.
|
||||
|
||||
#include <iomanip>
|
||||
#include <memory>
|
||||
#include <mutex> // NOLINT(build/c++11)
|
||||
#include <string>
|
||||
|
||||
#include "absl/memory/memory.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/lite/c/c_api_internal.h"
|
||||
#include "tensorflow/lite/profiling/time.h"
|
||||
#include "tensorflow/lite/tools/accuracy/csv_writer.h"
|
||||
#include "tensorflow/lite/tools/accuracy/ilsvrc/imagenet_model_evaluator.h"
|
||||
#include "tensorflow/lite/tools/accuracy/ilsvrc/imagenet_topk_eval.h"
|
||||
#include "tensorflow/core/platform/env.h"
|
||||
#include "tensorflow/core/util/command_line_flags.h"
|
||||
#include "tensorflow/lite/tools/command_line_flags.h"
|
||||
#include "tensorflow/lite/tools/evaluation/proto/evaluation_stages.pb.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace metrics {
|
||||
|
||||
namespace {
|
||||
|
||||
std::vector<double> GetAccuracies(
|
||||
const ImagenetTopKAccuracy::AccuracyStats& accuracy_stats) {
|
||||
std::vector<double> results;
|
||||
results.reserve(accuracy_stats.number_of_images);
|
||||
if (accuracy_stats.number_of_images > 0) {
|
||||
for (int n : accuracy_stats.topk_counts) {
|
||||
double accuracy = 0;
|
||||
if (accuracy_stats.number_of_images > 0) {
|
||||
accuracy = (n * 100.0) / accuracy_stats.number_of_images;
|
||||
}
|
||||
results.push_back(accuracy);
|
||||
}
|
||||
using ::tflite::evaluation::TopkAccuracyEvalMetrics;
|
||||
|
||||
constexpr char kNumThreadsFlag[] = "num_threads";
|
||||
constexpr char kOutputFilePathFlag[] = "output_file_path";
|
||||
|
||||
// TODO(b/130823599): Move to tools/evaluation/stages/topk_accuracy_eval_stage.
|
||||
// Computes total number of images processed & aggregates Top-K accuracies
|
||||
// into 'accuracies'.
|
||||
void AggregateAccuraciesAndNumImages(
|
||||
int k,
|
||||
const std::unordered_map<uint64_t, TopkAccuracyEvalMetrics>&
|
||||
shard_id_accuracy_metrics_map,
|
||||
const std::unordered_map<uint64_t, int>& shard_id_done_image_count_map,
|
||||
std::vector<double>* accuracies, int* num_done_images) {
|
||||
// Total images done.
|
||||
*num_done_images = 0;
|
||||
for (auto iter = shard_id_done_image_count_map.begin();
|
||||
iter != shard_id_done_image_count_map.end(); ++iter) {
|
||||
*num_done_images += iter->second;
|
||||
}
|
||||
|
||||
// Aggregated accuracies.
|
||||
for (int i = 0; i < k; ++i) {
|
||||
double correct_inferences = 0;
|
||||
double total_inferences = 0;
|
||||
for (auto iter = shard_id_done_image_count_map.begin();
|
||||
iter != shard_id_done_image_count_map.end(); ++iter) {
|
||||
const uint64_t shard_id = iter->first;
|
||||
const TopkAccuracyEvalMetrics& accuracy_metrics =
|
||||
shard_id_accuracy_metrics_map.at(shard_id);
|
||||
const int num_images = iter->second;
|
||||
correct_inferences += num_images * accuracy_metrics.topk_accuracies(i);
|
||||
total_inferences += num_images;
|
||||
}
|
||||
// Convert to percentage.
|
||||
accuracies->push_back(100.0 * correct_inferences / total_inferences);
|
||||
}
|
||||
return results;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// Writes results to a CSV file.
|
||||
// Writes results to a CSV file & logs progress to standard output with
|
||||
// `kLogDelayUs` microseconds.
|
||||
class ResultsWriter : public ImagenetModelEvaluator::Observer {
|
||||
public:
|
||||
explicit ResultsWriter(std::unique_ptr<CSVWriter> writer)
|
||||
: writer_(std::move(writer)) {}
|
||||
explicit ResultsWriter(int k, std::unique_ptr<CSVWriter> writer)
|
||||
: k_(k), writer_(std::move(writer)) {}
|
||||
|
||||
void OnEvaluationStart(const std::unordered_map<uint64_t, int>&
|
||||
shard_id_image_count_map) override {}
|
||||
|
||||
void OnSingleImageEvaluationComplete(
|
||||
uint64_t shard_id, const ImagenetTopKAccuracy::AccuracyStats& stats,
|
||||
const string& image) override;
|
||||
|
||||
private:
|
||||
std::unique_ptr<CSVWriter> writer_ GUARDED_BY(mu_);
|
||||
mutex mu_;
|
||||
};
|
||||
|
||||
void ResultsWriter::OnSingleImageEvaluationComplete(
|
||||
uint64_t shard_id, const ImagenetTopKAccuracy::AccuracyStats& stats,
|
||||
const string& image) {
|
||||
mutex_lock lock(mu_);
|
||||
TF_CHECK_OK(writer_->WriteRow(GetAccuracies(stats)));
|
||||
writer_->Flush();
|
||||
}
|
||||
|
||||
// Logs results to standard output with `kLogDelayUs` microseconds.
|
||||
class ResultsLogger : public ImagenetModelEvaluator::Observer {
|
||||
public:
|
||||
void OnEvaluationStart(const std::unordered_map<uint64_t, int>&
|
||||
shard_id_image_count_map) override;
|
||||
|
||||
void OnSingleImageEvaluationComplete(
|
||||
uint64_t shard_id, const ImagenetTopKAccuracy::AccuracyStats& stats,
|
||||
const string& image) override;
|
||||
void OnSingleImageEvaluationComplete(uint64_t shard_id,
|
||||
const TopkAccuracyEvalMetrics& metrics,
|
||||
const string& image) override;
|
||||
|
||||
private:
|
||||
uint64_t last_logged_time_us_ GUARDED_BY(mu_) = 0;
|
||||
int total_num_images_ GUARDED_BY(mu_);
|
||||
// For writing to CSV.
|
||||
int k_;
|
||||
std::unordered_map<uint64_t, TopkAccuracyEvalMetrics>
|
||||
shard_id_accuracy_metrics_map_;
|
||||
std::unordered_map<uint64_t, int> shard_id_done_image_count_map_;
|
||||
std::unique_ptr<CSVWriter> writer_;
|
||||
|
||||
// For logging to stdout.
|
||||
uint64_t last_logged_time_us_ = 0;
|
||||
int total_num_images_;
|
||||
static constexpr int kLogDelayUs = 500 * 1000;
|
||||
mutex mu_;
|
||||
|
||||
std::mutex mu_;
|
||||
};
|
||||
|
||||
void ResultsLogger::OnEvaluationStart(
|
||||
void ResultsWriter::OnEvaluationStart(
|
||||
const std::unordered_map<uint64_t, int>& shard_id_image_count_map) {
|
||||
int total_num_images = 0;
|
||||
for (const auto& kv : shard_id_image_count_map) {
|
||||
total_num_images += kv.second;
|
||||
}
|
||||
LOG(ERROR) << "Starting model evaluation: " << total_num_images;
|
||||
mutex_lock lock(mu_);
|
||||
std::lock_guard<std::mutex> lock(mu_);
|
||||
total_num_images_ = total_num_images;
|
||||
}
|
||||
|
||||
void ResultsLogger::OnSingleImageEvaluationComplete(
|
||||
uint64_t shard_id, const ImagenetTopKAccuracy::AccuracyStats& stats,
|
||||
void ResultsWriter::OnSingleImageEvaluationComplete(
|
||||
uint64_t shard_id,
|
||||
const tflite::evaluation::TopkAccuracyEvalMetrics& metrics,
|
||||
const string& image) {
|
||||
auto now_us = Env::Default()->NowMicros();
|
||||
int num_evaluated = stats.number_of_images;
|
||||
mutex_lock lock(mu_);
|
||||
std::lock_guard<std::mutex> lock(mu_);
|
||||
shard_id_done_image_count_map_[shard_id] += 1;
|
||||
shard_id_accuracy_metrics_map_[shard_id] = metrics;
|
||||
|
||||
int num_evaluated;
|
||||
std::vector<double> total_accuracies;
|
||||
AggregateAccuraciesAndNumImages(k_, shard_id_accuracy_metrics_map_,
|
||||
shard_id_done_image_count_map_,
|
||||
&total_accuracies, &num_evaluated);
|
||||
if (writer_->WriteRow(total_accuracies) != kTfLiteOk) {
|
||||
LOG(ERROR) << "Could not write to file";
|
||||
return;
|
||||
}
|
||||
writer_->Flush();
|
||||
|
||||
auto now_us = tflite::profiling::time::NowMicros();
|
||||
if ((now_us - last_logged_time_us_) >= kLogDelayUs) {
|
||||
last_logged_time_us_ = now_us;
|
||||
double current_percent = num_evaluated * 100.0 / total_num_images_;
|
||||
@ -116,44 +144,52 @@ void ResultsLogger::OnSingleImageEvaluationComplete(
|
||||
}
|
||||
|
||||
int Main(int argc, char* argv[]) {
|
||||
// TODO(shashishekhar): Make this binary configurable and model
|
||||
// agnostic.
|
||||
string output_file_path;
|
||||
int num_threads = 4;
|
||||
std::vector<Flag> flag_list = {
|
||||
Flag("output_file_path", &output_file_path, "Path to output file."),
|
||||
Flag("num_threads", &num_threads, "Number of threads."),
|
||||
std::vector<tflite::Flag> flag_list = {
|
||||
tflite::Flag::CreateFlag(kNumThreadsFlag, &num_threads,
|
||||
"Number of threads."),
|
||||
tflite::Flag::CreateFlag(kOutputFilePathFlag, &output_file_path,
|
||||
"Path to output file."),
|
||||
};
|
||||
Flags::Parse(&argc, argv, flag_list);
|
||||
tflite::Flags::Parse(&argc, const_cast<const char**>(argv), flag_list);
|
||||
|
||||
std::unique_ptr<ImagenetModelEvaluator> evaluator;
|
||||
CHECK(!output_file_path.empty()) << "Invalid output file path.";
|
||||
if (output_file_path.empty()) {
|
||||
LOG(ERROR) << "Invalid output file path.";
|
||||
return 0;
|
||||
}
|
||||
|
||||
CHECK(num_threads > 0) << "Invalid number of threads.";
|
||||
if (num_threads <= 0) {
|
||||
LOG(ERROR) << "Invalid number of threads.";
|
||||
return 0;
|
||||
}
|
||||
|
||||
TF_CHECK_OK(
|
||||
ImagenetModelEvaluator::Create(argc, argv, num_threads, &evaluator));
|
||||
if (ImagenetModelEvaluator::Create(argc, argv, num_threads, &evaluator) !=
|
||||
kTfLiteOk)
|
||||
return 0;
|
||||
|
||||
std::ofstream output_stream(output_file_path, std::ios::out);
|
||||
CHECK(output_stream) << "Unable to open output file path: '"
|
||||
<< output_file_path << "'";
|
||||
if (!output_stream) {
|
||||
LOG(ERROR) << "Unable to open output file path: '" << output_file_path
|
||||
<< "'";
|
||||
}
|
||||
|
||||
output_stream << std::setprecision(3) << std::fixed;
|
||||
std::vector<string> columns;
|
||||
columns.reserve(evaluator->params().num_ranks);
|
||||
for (int i = 0; i < evaluator->params().num_ranks; i++) {
|
||||
string column_name = "Top ";
|
||||
tensorflow::strings::StrAppend(&column_name, i + 1);
|
||||
std::string column_name = "Top ";
|
||||
column_name = column_name + std::to_string(i + 1);
|
||||
columns.push_back(column_name);
|
||||
}
|
||||
|
||||
ResultsWriter results_writer(
|
||||
evaluator->params().num_ranks,
|
||||
absl::make_unique<CSVWriter>(columns, &output_stream));
|
||||
ResultsLogger logger;
|
||||
evaluator->AddObserver(&results_writer);
|
||||
evaluator->AddObserver(&logger);
|
||||
LOG(ERROR) << "Starting evaluation with: " << num_threads << " threads.";
|
||||
TF_CHECK_OK(evaluator->EvaluateModel());
|
||||
evaluator->EvaluateModel();
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
@ -15,31 +15,34 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/lite/tools/accuracy/ilsvrc/imagenet_model_evaluator.h"
|
||||
|
||||
#include <dirent.h>
|
||||
|
||||
#include <fstream>
|
||||
#include <iomanip>
|
||||
#include <mutex> // NOLINT(build/c++11)
|
||||
#include <string>
|
||||
#include <thread> // NOLINT(build/c++11)
|
||||
#include <vector>
|
||||
|
||||
#include "absl/memory/memory.h"
|
||||
#include "tensorflow/cc/framework/scope.h"
|
||||
#include "tensorflow/core/lib/core/blocking_counter.h"
|
||||
#include "tensorflow/core/lib/core/threadpool.h"
|
||||
#include "tensorflow/core/platform/init_main.h"
|
||||
#include "tensorflow/core/platform/mutex.h"
|
||||
#include "tensorflow/core/public/session.h"
|
||||
#include "tensorflow/core/util/command_line_flags.h"
|
||||
#include "tensorflow/lite/tools/accuracy/eval_pipeline.h"
|
||||
#include "tensorflow/lite/tools/accuracy/eval_pipeline_builder.h"
|
||||
#include "tensorflow/lite/tools/accuracy/file_reader_stage.h"
|
||||
#include "tensorflow/lite/tools/accuracy/ilsvrc/imagenet_topk_eval.h"
|
||||
#include "tensorflow/lite/tools/accuracy/ilsvrc/inception_preprocessing.h"
|
||||
#include "tensorflow/lite/tools/accuracy/run_tflite_model_stage.h"
|
||||
#include "tensorflow/lite/tools/accuracy/utils.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/lite/c/c_api_internal.h"
|
||||
#include "tensorflow/lite/tools/command_line_flags.h"
|
||||
#include "tensorflow/lite/tools/evaluation/proto/evaluation_config.pb.h"
|
||||
#include "tensorflow/lite/tools/evaluation/proto/evaluation_stages.pb.h"
|
||||
#include "tensorflow/lite/tools/evaluation/stages/image_classification_stage.h"
|
||||
#include "tensorflow/lite/tools/evaluation/utils.h"
|
||||
|
||||
namespace {
|
||||
using tensorflow::string;
|
||||
|
||||
string StripTrailingSlashes(const string& path) {
|
||||
constexpr char kNumImagesFlag[] = "num_images";
|
||||
constexpr char kModelOutputLabelsFlag[] = "model_output_labels";
|
||||
constexpr char kGroundTruthImagesPathFlag[] = "ground_truth_images_path";
|
||||
constexpr char kGroundTruthLabelsFlag[] = "ground_truth_labels";
|
||||
constexpr char kBlacklistFilePathFlag[] = "blacklist_file_path";
|
||||
constexpr char kModelFileFlag[] = "model_file";
|
||||
|
||||
std::string StripTrailingSlashes(const std::string& path) {
|
||||
int end = path.size();
|
||||
while (end > 0 && path[end - 1] == '/') {
|
||||
end--;
|
||||
@ -47,12 +50,6 @@ string StripTrailingSlashes(const string& path) {
|
||||
return path.substr(0, end);
|
||||
}
|
||||
|
||||
tensorflow::Tensor CreateStringTensor(const string& value) {
|
||||
tensorflow::Tensor tensor(tensorflow::DT_STRING, tensorflow::TensorShape({}));
|
||||
tensor.scalar<string>()() = value;
|
||||
return tensor;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
std::vector<T> GetFirstN(const std::vector<T>& v, int n) {
|
||||
if (n >= v.size()) return v;
|
||||
@ -62,7 +59,9 @@ std::vector<T> GetFirstN(const std::vector<T>& v, int n) {
|
||||
|
||||
template <typename T>
|
||||
std::vector<std::vector<T>> Split(const std::vector<T>& v, int n) {
|
||||
CHECK_GT(n, 0);
|
||||
if (n <= 0) {
|
||||
return std::vector<std::vector<T>>();
|
||||
}
|
||||
std::vector<std::vector<T>> vecs(n);
|
||||
int input_index = 0;
|
||||
int vec_index = 0;
|
||||
@ -71,7 +70,6 @@ std::vector<std::vector<T>> Split(const std::vector<T>& v, int n) {
|
||||
vec_index = (vec_index + 1) % n;
|
||||
input_index++;
|
||||
}
|
||||
CHECK_EQ(vecs.size(), n);
|
||||
return vecs;
|
||||
}
|
||||
|
||||
@ -90,149 +88,124 @@ class CompositeObserver : public ImagenetModelEvaluator::Observer {
|
||||
|
||||
void OnEvaluationStart(const std::unordered_map<uint64_t, int>&
|
||||
shard_id_image_count_map) override {
|
||||
mutex_lock lock(mu_);
|
||||
std::lock_guard<std::mutex> lock(mu_);
|
||||
for (auto observer : observers_) {
|
||||
observer->OnEvaluationStart(shard_id_image_count_map);
|
||||
}
|
||||
}
|
||||
|
||||
void OnSingleImageEvaluationComplete(
|
||||
uint64_t shard_id, const ImagenetTopKAccuracy::AccuracyStats& stats,
|
||||
const string& image) override {
|
||||
mutex_lock lock(mu_);
|
||||
uint64_t shard_id,
|
||||
const tflite::evaluation::TopkAccuracyEvalMetrics& metrics,
|
||||
const std::string& image) override {
|
||||
std::lock_guard<std::mutex> lock(mu_);
|
||||
for (auto observer : observers_) {
|
||||
observer->OnSingleImageEvaluationComplete(shard_id, stats, image);
|
||||
observer->OnSingleImageEvaluationComplete(shard_id, metrics, image);
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
const std::vector<ImagenetModelEvaluator::Observer*>& observers_
|
||||
GUARDED_BY(mu_);
|
||||
mutex mu_;
|
||||
const std::vector<ImagenetModelEvaluator::Observer*>& observers_;
|
||||
std::mutex mu_;
|
||||
};
|
||||
|
||||
/*static*/ Status ImagenetModelEvaluator::Create(
|
||||
/*static*/ TfLiteStatus ImagenetModelEvaluator::Create(
|
||||
int argc, char* argv[], int num_threads,
|
||||
std::unique_ptr<ImagenetModelEvaluator>* model_evaluator) {
|
||||
Params params;
|
||||
const std::vector<Flag> flag_list = {
|
||||
Flag("model_output_labels", ¶ms.model_output_labels_path,
|
||||
"Path to labels that correspond to output of model."
|
||||
" E.g. in case of mobilenet, this is the path to label "
|
||||
"file where each label is in the same order as the output"
|
||||
" of the model."),
|
||||
Flag("ground_truth_images_path", ¶ms.ground_truth_images_path,
|
||||
"Path to ground truth images."),
|
||||
Flag("ground_truth_labels", ¶ms.ground_truth_labels_path,
|
||||
"Path to ground truth labels."),
|
||||
Flag("num_images", ¶ms.number_of_images,
|
||||
"Number of examples to evaluate, pass 0 for all "
|
||||
"examples. Default: 100"),
|
||||
Flag("blacklist_file_path", ¶ms.blacklist_file_path,
|
||||
"Path to blacklist file (optional)."
|
||||
"Path to blacklist file where each line is a single integer that is "
|
||||
"equal to number of blacklisted image."),
|
||||
Flag("model_file", ¶ms.model_file_path,
|
||||
"Path to test tflite model file."),
|
||||
};
|
||||
const bool parse_result = Flags::Parse(&argc, argv, flag_list);
|
||||
if (!parse_result)
|
||||
return errors::InvalidArgument("Invalid command line flags");
|
||||
::tensorflow::port::InitMain(argv[0], &argc, &argv);
|
||||
params.number_of_images = 100;
|
||||
std::vector<tflite::Flag> flag_list = {
|
||||
tflite::Flag::CreateFlag(kNumImagesFlag, ¶ms.number_of_images,
|
||||
"Number of examples to evaluate, pass 0 for all "
|
||||
"examples. Default: 100"),
|
||||
tflite::Flag::CreateFlag(
|
||||
kModelOutputLabelsFlag, ¶ms.model_output_labels_path,
|
||||
|
||||
TF_RETURN_WITH_CONTEXT_IF_ERROR(
|
||||
Env::Default()->IsDirectory(params.ground_truth_images_path),
|
||||
"Invalid ground truth data path.");
|
||||
TF_RETURN_WITH_CONTEXT_IF_ERROR(
|
||||
Env::Default()->FileExists(params.ground_truth_labels_path),
|
||||
"Invalid ground truth labels path.");
|
||||
TF_RETURN_WITH_CONTEXT_IF_ERROR(
|
||||
Env::Default()->FileExists(params.model_output_labels_path),
|
||||
"Invalid model output labels path.");
|
||||
"Path to labels that correspond to output of model."
|
||||
" E.g. in case of mobilenet, this is the path to label "
|
||||
"file where each label is in the same order as the output"
|
||||
" of the model."),
|
||||
tflite::Flag::CreateFlag(
|
||||
kGroundTruthImagesPathFlag, ¶ms.ground_truth_images_path,
|
||||
|
||||
if (!params.blacklist_file_path.empty()) {
|
||||
TF_RETURN_WITH_CONTEXT_IF_ERROR(
|
||||
Env::Default()->FileExists(params.blacklist_file_path),
|
||||
"Invalid blacklist path.");
|
||||
}
|
||||
"Path to ground truth images. These will be evaluated in "
|
||||
"alphabetical order of filename"),
|
||||
tflite::Flag::CreateFlag(
|
||||
kGroundTruthLabelsFlag, ¶ms.ground_truth_labels_path,
|
||||
"Path to ground truth labels, corresponding to alphabetical ordering "
|
||||
"of ground truth images."),
|
||||
tflite::Flag::CreateFlag(
|
||||
kBlacklistFilePathFlag, ¶ms.blacklist_file_path,
|
||||
"Path to blacklist file (optional) where each line is a single "
|
||||
"integer that is "
|
||||
"equal to index number of blacklisted image."),
|
||||
tflite::Flag::CreateFlag(kModelFileFlag, ¶ms.model_file_path,
|
||||
"Path to test tflite model file.")};
|
||||
tflite::Flags::Parse(&argc, const_cast<const char**>(argv), flag_list);
|
||||
|
||||
if (params.number_of_images < 0) {
|
||||
return errors::InvalidArgument("Invalid: num_examples");
|
||||
LOG(ERROR) << "Invalid: num_examples";
|
||||
return kTfLiteError;
|
||||
}
|
||||
|
||||
utils::ModelInfo model_info;
|
||||
TF_RETURN_WITH_CONTEXT_IF_ERROR(
|
||||
utils::GetTFliteModelInfo(params.model_file_path, &model_info),
|
||||
"Invalid TFLite model.");
|
||||
|
||||
*model_evaluator = absl::make_unique<ImagenetModelEvaluator>(
|
||||
model_info, params, num_threads);
|
||||
return Status::OK();
|
||||
*model_evaluator =
|
||||
absl::make_unique<ImagenetModelEvaluator>(params, num_threads);
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
struct ImageLabel {
|
||||
string image;
|
||||
string label;
|
||||
std::string image;
|
||||
std::string label;
|
||||
};
|
||||
|
||||
Status EvaluateModelForShard(const uint64_t shard_id,
|
||||
const std::vector<ImageLabel>& image_labels,
|
||||
const std::vector<string>& model_labels,
|
||||
const utils::ModelInfo& model_info,
|
||||
const ImagenetModelEvaluator::Params& params,
|
||||
ImagenetModelEvaluator::Observer* observer,
|
||||
ImagenetTopKAccuracy* eval) {
|
||||
const TensorShape& input_shape = model_info.input_shapes[0];
|
||||
const int image_height = input_shape.dim_size(1);
|
||||
const int image_width = input_shape.dim_size(2);
|
||||
TfLiteStatus EvaluateModelForShard(const uint64_t shard_id,
|
||||
const std::vector<ImageLabel>& image_labels,
|
||||
const std::vector<std::string>& model_labels,
|
||||
const ImagenetModelEvaluator::Params& params,
|
||||
ImagenetModelEvaluator::Observer* observer,
|
||||
int num_ranks) {
|
||||
tflite::evaluation::EvaluationStageConfig eval_config;
|
||||
eval_config.set_name("image_classification");
|
||||
auto* classification_params = eval_config.mutable_specification()
|
||||
->mutable_image_classification_params();
|
||||
auto* inference_params = classification_params->mutable_inference_params();
|
||||
inference_params->set_model_file_path(params.model_file_path);
|
||||
classification_params->mutable_topk_accuracy_eval_params()->set_k(num_ranks);
|
||||
|
||||
RunTFLiteModelStage::Params tfl_model_params;
|
||||
tfl_model_params.model_file_path = params.model_file_path;
|
||||
|
||||
tfl_model_params.input_type = {model_info.input_types[0]};
|
||||
tfl_model_params.output_type = {model_info.input_types[0]};
|
||||
|
||||
Scope root = Scope::NewRootScope();
|
||||
FileReaderStage reader;
|
||||
InceptionPreprocessingStage inc(image_height, image_width,
|
||||
model_info.input_types[0]);
|
||||
RunTFLiteModelStage tfl_model_stage(tfl_model_params);
|
||||
EvalPipelineBuilder builder;
|
||||
|
||||
std::unique_ptr<EvalPipeline> eval_pipeline;
|
||||
|
||||
auto build_status = builder.WithInputStage(&reader)
|
||||
.WithPreprocessingStage(&inc)
|
||||
.WithRunModelStage(&tfl_model_stage)
|
||||
.WithAccuracyEval(eval)
|
||||
.WithInput("input_file", DT_STRING)
|
||||
.Build(root, &eval_pipeline);
|
||||
TF_RETURN_WITH_CONTEXT_IF_ERROR(build_status,
|
||||
"Failure while building eval pipeline.");
|
||||
std::unique_ptr<Session> session(NewSession(SessionOptions()));
|
||||
|
||||
TF_RETURN_IF_ERROR(eval_pipeline->AttachSession(std::move(session)));
|
||||
tflite::evaluation::ImageClassificationStage eval(eval_config);
|
||||
eval.SetAllLabels(model_labels);
|
||||
TF_LITE_ENSURE_STATUS(eval.Init());
|
||||
|
||||
for (const auto& image_label : image_labels) {
|
||||
TF_CHECK_OK(eval_pipeline->Run(CreateStringTensor(image_label.image),
|
||||
CreateStringTensor(image_label.label)));
|
||||
eval.SetInputs(image_label.image, image_label.label);
|
||||
|
||||
TF_LITE_ENSURE_STATUS(eval.Run());
|
||||
observer->OnSingleImageEvaluationComplete(
|
||||
shard_id, eval->GetTopKAccuracySoFar(), image_label.image);
|
||||
shard_id,
|
||||
eval.LatestMetrics()
|
||||
.process_metrics()
|
||||
.image_classification_metrics()
|
||||
.topk_accuracy_metrics(),
|
||||
image_label.image);
|
||||
}
|
||||
return Status::OK();
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
Status FilterBlackListedImages(const string& blacklist_file_path,
|
||||
std::vector<ImageLabel>* image_labels) {
|
||||
// TODO(b/130823599): Move to tools/evaluation/utils.
|
||||
TfLiteStatus FilterBlackListedImages(const std::string& blacklist_file_path,
|
||||
std::vector<ImageLabel>* image_labels) {
|
||||
if (!blacklist_file_path.empty()) {
|
||||
std::vector<string> lines;
|
||||
TF_RETURN_IF_ERROR(utils::ReadFileLines(blacklist_file_path, &lines));
|
||||
std::vector<std::string> lines;
|
||||
if (!tflite::evaluation::ReadFileLines(blacklist_file_path, &lines)) {
|
||||
LOG(ERROR) << "Could not read: " << blacklist_file_path;
|
||||
return kTfLiteError;
|
||||
}
|
||||
std::vector<int> blacklist_ids;
|
||||
blacklist_ids.reserve(lines.size());
|
||||
// Populate blacklist_ids with indices of images.
|
||||
std::transform(lines.begin(), lines.end(),
|
||||
std::back_inserter(blacklist_ids),
|
||||
[](const string& val) { return std::stoi(val) - 1; });
|
||||
[](const std::string& val) { return std::stoi(val) - 1; });
|
||||
|
||||
std::vector<ImageLabel> filtered_images;
|
||||
std::sort(blacklist_ids.begin(), blacklist_ids.end());
|
||||
@ -251,38 +224,51 @@ Status FilterBlackListedImages(const string& blacklist_file_path,
|
||||
}
|
||||
|
||||
if (filtered_images.size() != size_post_filtering) {
|
||||
return errors::Internal("Invalid number of filtered images");
|
||||
LOG(ERROR) << "Invalid number of filtered images";
|
||||
return kTfLiteError;
|
||||
}
|
||||
*image_labels = filtered_images;
|
||||
}
|
||||
return Status::OK();
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
Status ImagenetModelEvaluator::EvaluateModel() const {
|
||||
if (model_info_.input_shapes.size() != 1) {
|
||||
return errors::InvalidArgument("Invalid input shape");
|
||||
// TODO(b/130823599): Move to tools/evaluation/utils.
|
||||
TfLiteStatus GetSortedFileNames(const std::string dir_path,
|
||||
std::vector<std::string>* result) {
|
||||
DIR* dir;
|
||||
struct dirent* ent;
|
||||
if (result == nullptr) {
|
||||
LOG(ERROR) << "result cannot be nullptr";
|
||||
return kTfLiteError;
|
||||
}
|
||||
|
||||
const TensorShape& input_shape = model_info_.input_shapes[0];
|
||||
// Input should be of the shape {1, height, width, 3}
|
||||
if (input_shape.dims() != 4 || input_shape.dim_size(3) != 3) {
|
||||
return errors::InvalidArgument("Invalid input shape for the model.");
|
||||
if ((dir = opendir(dir_path.c_str())) != nullptr) {
|
||||
while ((ent = readdir(dir)) != nullptr) {
|
||||
std::string filename(std::string(ent->d_name));
|
||||
if (filename.size() <= 2) continue;
|
||||
result->emplace_back(dir_path + "/" + filename);
|
||||
}
|
||||
closedir(dir);
|
||||
} else {
|
||||
LOG(ERROR) << "Could not open dir: " << dir_path;
|
||||
return kTfLiteError;
|
||||
}
|
||||
std::sort(result->begin(), result->end());
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
string data_path =
|
||||
TfLiteStatus ImagenetModelEvaluator::EvaluateModel() const {
|
||||
const std::string data_path =
|
||||
StripTrailingSlashes(params_.ground_truth_images_path) + "/";
|
||||
|
||||
const string imagenet_file_pattern = data_path + kImagenetFilePattern;
|
||||
std::vector<string> image_files;
|
||||
TF_CHECK_OK(
|
||||
Env::Default()->GetMatchingPaths(imagenet_file_pattern, &image_files));
|
||||
std::vector<std::string> image_files;
|
||||
TF_LITE_ENSURE_STATUS(GetSortedFileNames(data_path, &image_files));
|
||||
std::vector<string> ground_truth_image_labels;
|
||||
TF_CHECK_OK(utils::ReadFileLines(params_.ground_truth_labels_path,
|
||||
&ground_truth_image_labels));
|
||||
CHECK_EQ(image_files.size(), ground_truth_image_labels.size());
|
||||
|
||||
// Process files in filename sorted order.
|
||||
std::sort(image_files.begin(), image_files.end());
|
||||
if (!tflite::evaluation::ReadFileLines(params_.ground_truth_labels_path,
|
||||
&ground_truth_image_labels))
|
||||
return kTfLiteError;
|
||||
if (image_files.size() != ground_truth_image_labels.size()) {
|
||||
LOG(ERROR) << "Images and ground truth labels don't match";
|
||||
return kTfLiteError;
|
||||
}
|
||||
|
||||
std::vector<ImageLabel> image_labels;
|
||||
image_labels.reserve(image_files.size());
|
||||
@ -291,56 +277,55 @@ Status ImagenetModelEvaluator::EvaluateModel() const {
|
||||
}
|
||||
|
||||
// Filter any blacklisted images.
|
||||
TF_CHECK_OK(
|
||||
FilterBlackListedImages(params_.blacklist_file_path, &image_labels));
|
||||
if (FilterBlackListedImages(params_.blacklist_file_path, &image_labels) !=
|
||||
kTfLiteOk) {
|
||||
LOG(ERROR) << "Could not filter by blacklist";
|
||||
return kTfLiteError;
|
||||
}
|
||||
|
||||
if (params_.number_of_images > 0) {
|
||||
image_labels = GetFirstN(image_labels, params_.number_of_images);
|
||||
}
|
||||
|
||||
std::vector<string> model_labels;
|
||||
TF_RETURN_IF_ERROR(
|
||||
utils::ReadFileLines(params_.model_output_labels_path, &model_labels));
|
||||
if (model_labels.size() != 1001) {
|
||||
return errors::InvalidArgument("Invalid number of labels: ",
|
||||
model_labels.size());
|
||||
if (!tflite::evaluation::ReadFileLines(params_.model_output_labels_path,
|
||||
&model_labels)) {
|
||||
LOG(ERROR) << "Could not read: " << params_.model_output_labels_path;
|
||||
return kTfLiteError;
|
||||
}
|
||||
if (model_labels.size() != 1001) {
|
||||
LOG(ERROR) << "Invalid number of labels: " << model_labels.size();
|
||||
return kTfLiteError;
|
||||
}
|
||||
|
||||
ImagenetTopKAccuracy eval(model_labels, params_.num_ranks);
|
||||
|
||||
auto img_labels = Split(image_labels, num_threads_);
|
||||
|
||||
BlockingCounter counter(num_threads_);
|
||||
|
||||
CompositeObserver observer(observers_);
|
||||
|
||||
::tensorflow::thread::ThreadPool pool(Env::Default(), "evaluation_pool",
|
||||
num_threads_);
|
||||
std::vector<std::thread> thread_pool;
|
||||
bool all_okay = true;
|
||||
std::unordered_map<uint64_t, int> shard_id_image_count_map;
|
||||
std::vector<std::function<void()>> thread_funcs;
|
||||
thread_funcs.reserve(num_threads_);
|
||||
thread_pool.reserve(num_threads_);
|
||||
for (int i = 0; i < num_threads_; i++) {
|
||||
const auto& image_label = img_labels[i];
|
||||
const uint64_t shard_id = i + 1;
|
||||
shard_id_image_count_map[shard_id] = image_label.size();
|
||||
auto func = [shard_id, &image_label, &model_labels, this, &observer, &eval,
|
||||
&counter]() {
|
||||
TF_CHECK_OK(EvaluateModelForShard(shard_id, image_label, model_labels,
|
||||
model_info_, params_, &observer,
|
||||
&eval));
|
||||
counter.DecrementCount();
|
||||
auto func = [shard_id, &image_label, &model_labels, this, &observer,
|
||||
&all_okay]() {
|
||||
if (EvaluateModelForShard(shard_id, image_label, model_labels, params_,
|
||||
&observer, params_.num_ranks) != kTfLiteOk) {
|
||||
all_okay = all_okay && false;
|
||||
}
|
||||
};
|
||||
thread_funcs.push_back(func);
|
||||
thread_pool.push_back(std::thread(func));
|
||||
}
|
||||
|
||||
observer.OnEvaluationStart(shard_id_image_count_map);
|
||||
for (const auto& func : thread_funcs) {
|
||||
pool.Schedule(func);
|
||||
for (auto& thread : thread_pool) {
|
||||
thread.join();
|
||||
}
|
||||
|
||||
counter.Wait();
|
||||
|
||||
return Status::OK();
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
} // namespace metrics
|
||||
|
@ -13,15 +13,14 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_LITE_TOOLS_ACCURACY_IMAGENET_MODEL_EVALUATOR_H_
|
||||
#define TENSORFLOW_LITE_TOOLS_ACCURACY_IMAGENET_MODEL_EVALUATOR_H_
|
||||
#ifndef TENSORFLOW_LITE_TOOLS_ACCURACY_ILSVRC_IMAGENET_MODEL_EVALUATOR_H_
|
||||
#define TENSORFLOW_LITE_TOOLS_ACCURACY_ILSVRC_IMAGENET_MODEL_EVALUATOR_H_
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/lite/tools/accuracy/ilsvrc/imagenet_topk_eval.h"
|
||||
#include "tensorflow/lite/tools/accuracy/utils.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/lite/c/c_api_internal.h"
|
||||
#include "tensorflow/lite/tools/evaluation/proto/evaluation_stages.pb.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace metrics {
|
||||
@ -42,26 +41,26 @@ class ImagenetModelEvaluator {
|
||||
public:
|
||||
struct Params {
|
||||
// Path to ground truth images.
|
||||
string ground_truth_images_path;
|
||||
std::string ground_truth_images_path;
|
||||
|
||||
// Path to labels file for ground truth image.
|
||||
// This file should be generated with the scripts.
|
||||
string ground_truth_labels_path;
|
||||
std::string ground_truth_labels_path;
|
||||
|
||||
// This is word labels generated by the model. The category
|
||||
// indices of output probabilities generated by the model maybe different
|
||||
// from the indices in the imagenet dataset.
|
||||
string model_output_labels_path;
|
||||
std::string model_output_labels_path;
|
||||
|
||||
// Path to the model file.
|
||||
string model_file_path;
|
||||
std::string model_file_path;
|
||||
|
||||
// Path to black list file. 1762 images were blacklisted from
|
||||
// original ILSVRC dataset. This black list file is present in
|
||||
// ILSVRC2014 devkit. Please refer to readme.txt of the ILSVRC2014
|
||||
// devkit for details.
|
||||
// This file is a list of image indices in a sorted order.
|
||||
string blacklist_file_path;
|
||||
std::string blacklist_file_path;
|
||||
|
||||
// The maximum number of images to calculate accuracy.
|
||||
// 0 means all images, a positive number means only the specified
|
||||
@ -90,19 +89,20 @@ class ImagenetModelEvaluator {
|
||||
|
||||
// Called when evaluation was complete for `image`.
|
||||
virtual void OnSingleImageEvaluationComplete(
|
||||
uint64_t shard_id, const ImagenetTopKAccuracy::AccuracyStats& stats,
|
||||
uint64_t shard_id,
|
||||
const tflite::evaluation::TopkAccuracyEvalMetrics& metrics,
|
||||
const string& image) = 0;
|
||||
|
||||
virtual ~Observer() = default;
|
||||
};
|
||||
|
||||
ImagenetModelEvaluator(const utils::ModelInfo& model_info,
|
||||
const Params& params, const int num_threads)
|
||||
: model_info_(model_info), params_(params), num_threads_(num_threads) {}
|
||||
ImagenetModelEvaluator(const Params& params, const int num_threads)
|
||||
: params_(params), num_threads_(num_threads) {}
|
||||
|
||||
// Factory method to create the evaluator by parsing command line arguments.
|
||||
static Status Create(int argc, char* argv[], int num_threads,
|
||||
std::unique_ptr<ImagenetModelEvaluator>* evaluator);
|
||||
static TfLiteStatus Create(
|
||||
int argc, char* argv[], int num_threads,
|
||||
std::unique_ptr<ImagenetModelEvaluator>* evaluator);
|
||||
|
||||
// Adds an observer that can observe evaluation events..
|
||||
void AddObserver(Observer* observer) { observers_.push_back(observer); }
|
||||
@ -110,10 +110,9 @@ class ImagenetModelEvaluator {
|
||||
const Params& params() const { return params_; }
|
||||
|
||||
// Evaluates the provided model over the dataset.
|
||||
Status EvaluateModel() const;
|
||||
TfLiteStatus EvaluateModel() const;
|
||||
|
||||
private:
|
||||
const utils::ModelInfo model_info_;
|
||||
const Params params_;
|
||||
const int num_threads_;
|
||||
std::vector<Observer*> observers_;
|
||||
|
@ -6,8 +6,7 @@ licenses(["notice"]) # Apache 2.0
|
||||
|
||||
load("//tensorflow:tensorflow.bzl", "tf_cc_binary")
|
||||
load("//tensorflow/lite:special_rules.bzl", "tflite_portable_test_suite")
|
||||
load("//tensorflow/lite:build_def.bzl", "tflite_copts")
|
||||
load("//tensorflow/lite:build_def.bzl", "tflite_linkopts")
|
||||
load("//tensorflow/lite:build_def.bzl", "tflite_copts", "tflite_linkopts")
|
||||
|
||||
common_copts = ["-Wall"] + tflite_copts()
|
||||
|
||||
@ -70,31 +69,9 @@ cc_test(
|
||||
],
|
||||
deps = [
|
||||
":benchmark_tflite_model_lib",
|
||||
":command_line_flags",
|
||||
"//tensorflow/lite:framework",
|
||||
"//tensorflow/lite/testing:util",
|
||||
"@com_google_googletest//:gtest",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "command_line_flags",
|
||||
srcs = ["command_line_flags.cc"],
|
||||
hdrs = ["command_line_flags.h"],
|
||||
copts = common_copts,
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "command_line_flags_test",
|
||||
srcs = ["command_line_flags_test.cc"],
|
||||
copts = common_copts,
|
||||
tags = [
|
||||
"tflite_not_portable_ios", # TODO(b/117786830)
|
||||
],
|
||||
visibility = ["//visibility:private"],
|
||||
deps = [
|
||||
":command_line_flags",
|
||||
"//tensorflow/lite/testing:util",
|
||||
"//tensorflow/lite/tools:command_line_flags",
|
||||
"@com_google_googletest//:gtest",
|
||||
],
|
||||
)
|
||||
@ -143,11 +120,11 @@ cc_library(
|
||||
copts = common_copts,
|
||||
deps = [
|
||||
":benchmark_params",
|
||||
":command_line_flags",
|
||||
":logging",
|
||||
"//tensorflow/core:stats_calculator_portable",
|
||||
"//tensorflow/lite:framework",
|
||||
"//tensorflow/lite/profiling:time",
|
||||
"//tensorflow/lite/tools:command_line_flags",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -23,9 +23,9 @@ limitations under the License.
|
||||
#include <unordered_set>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/lite/tools/benchmark/benchmark_params.h"
|
||||
#include "tensorflow/lite/tools/benchmark/command_line_flags.h"
|
||||
#include "tensorflow/core/util/stats_calculator.h"
|
||||
#include "tensorflow/lite/tools/benchmark/benchmark_params.h"
|
||||
#include "tensorflow/lite/tools/command_line_flags.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace benchmark {
|
||||
|
@ -21,7 +21,7 @@ limitations under the License.
|
||||
#include "tensorflow/lite/interpreter.h"
|
||||
#include "tensorflow/lite/testing/util.h"
|
||||
#include "tensorflow/lite/tools/benchmark/benchmark_tflite_model.h"
|
||||
#include "tensorflow/lite/tools/benchmark/command_line_flags.h"
|
||||
#include "tensorflow/lite/tools/command_line_flags.h"
|
||||
|
||||
namespace {
|
||||
const std::string* g_model_path = nullptr;
|
||||
|
@ -10,7 +10,7 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/lite/tools/benchmark/command_line_flags.h"
|
||||
#include "tensorflow/lite/tools/command_line_flags.h"
|
||||
|
||||
#include <cstring>
|
||||
#include <sstream>
|
@ -13,7 +13,8 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/lite/tools/benchmark/command_line_flags.h"
|
||||
#include "tensorflow/lite/tools/command_line_flags.h"
|
||||
|
||||
#include <gmock/gmock.h>
|
||||
#include <gtest/gtest.h>
|
||||
#include "tensorflow/lite/testing/util.h"
|
Loading…
Reference in New Issue
Block a user