Internal Change
PiperOrigin-RevId: 241946326
This commit is contained in:
parent
30bcf6ae7c
commit
03df1dd502
@ -17,7 +17,11 @@ package(default_visibility = ["//visibility:public"])
|
|||||||
|
|
||||||
licenses(["notice"]) # Apache 2.0
|
licenses(["notice"]) # Apache 2.0
|
||||||
|
|
||||||
load("//tensorflow/lite:build_def.bzl", "tflite_copts")
|
load("//tensorflow/lite:build_def.bzl", "tflite_copts", "tflite_linkopts")
|
||||||
|
|
||||||
|
exports_files(glob([
|
||||||
|
"testdata/**",
|
||||||
|
]))
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "evaluation_stage",
|
name = "evaluation_stage",
|
||||||
@ -28,3 +32,25 @@ cc_library(
|
|||||||
"//tensorflow/lite/tools/evaluation/proto:evaluation_config_cc_proto",
|
"//tensorflow/lite/tools/evaluation/proto:evaluation_config_cc_proto",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "utils",
|
||||||
|
srcs = ["utils.cc"],
|
||||||
|
hdrs = ["utils.h"],
|
||||||
|
copts = tflite_copts(),
|
||||||
|
deps = [
|
||||||
|
"//tensorflow/core:tflite_portable_logging",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
cc_test(
|
||||||
|
name = "utils_test",
|
||||||
|
srcs = ["utils_test.cc"],
|
||||||
|
data = ["testdata/labels.txt"],
|
||||||
|
linkopts = tflite_linkopts(),
|
||||||
|
linkstatic = 1,
|
||||||
|
deps = [
|
||||||
|
":utils",
|
||||||
|
"@com_google_googletest//:gtest_main",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
@ -19,12 +19,13 @@ package tflite.evaluation;
|
|||||||
|
|
||||||
// Defines the functionality executed by an EvaluationStage.
|
// Defines the functionality executed by an EvaluationStage.
|
||||||
//
|
//
|
||||||
// Next ID: 4
|
// Next ID: 5
|
||||||
message ProcessSpecification {
|
message ProcessSpecification {
|
||||||
oneof params {
|
oneof params {
|
||||||
ImagePreprocessingParams image_preprocessing_params = 1;
|
ImagePreprocessingParams image_preprocessing_params = 1;
|
||||||
TopkAccuracyEvalParams topk_accuracy_eval_params = 2;
|
TopkAccuracyEvalParams topk_accuracy_eval_params = 2;
|
||||||
TfliteInferenceParams tflite_inference_params = 3;
|
TfliteInferenceParams tflite_inference_params = 3;
|
||||||
|
ImageClassificationParams image_classification_params = 4;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -48,13 +49,14 @@ message LatencyMetrics {
|
|||||||
// Contains process-specific metrics, which may differ based on what an
|
// Contains process-specific metrics, which may differ based on what an
|
||||||
// EvaluationStage does.
|
// EvaluationStage does.
|
||||||
//
|
//
|
||||||
// Next ID: 4
|
// Next ID: 5
|
||||||
message ProcessMetrics {
|
message ProcessMetrics {
|
||||||
optional LatencyMetrics total_latency = 1;
|
optional LatencyMetrics total_latency = 1;
|
||||||
|
|
||||||
oneof stage_metrics {
|
oneof stage_metrics {
|
||||||
TopkAccuracyEvalMetrics topk_accuracy_metrics = 2;
|
TopkAccuracyEvalMetrics topk_accuracy_metrics = 2;
|
||||||
TfliteInferenceMetrics tflite_inference_metrics = 3;
|
TfliteInferenceMetrics tflite_inference_metrics = 3;
|
||||||
|
ImageClassificationMetrics image_classification_metrics = 4;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -89,8 +91,7 @@ message TfliteInferenceParams {
|
|||||||
}
|
}
|
||||||
optional Delegate delegate = 2;
|
optional Delegate delegate = 2;
|
||||||
// Number of threads available to the TFLite Interpreter.
|
// Number of threads available to the TFLite Interpreter.
|
||||||
// -1 defaults to best estimate.
|
optional int32 num_threads = 3 [default = 1];
|
||||||
optional int32 num_threads = 3 [default = -1];
|
|
||||||
|
|
||||||
// Defines how many times the TFLite Interpreter is invoked for every input.
|
// Defines how many times the TFLite Interpreter is invoked for every input.
|
||||||
// This helps benchmark cases where extensive pre-processing might not be
|
// This helps benchmark cases where extensive pre-processing might not be
|
||||||
@ -121,8 +122,36 @@ message TopkAccuracyEvalMetrics {
|
|||||||
// A repeated field of size |k| where the ith element denotes the fraction of
|
// A repeated field of size |k| where the ith element denotes the fraction of
|
||||||
// samples for which the correct label was present in the top (i + 1) model
|
// samples for which the correct label was present in the top (i + 1) model
|
||||||
// outputs.
|
// outputs.
|
||||||
// For example, topk_accuracy_percentages(1) will contain the fraction of
|
// For example, topk_accuracies(1) will contain the fraction of
|
||||||
// samples for which the model returned the correct label as the top first or
|
// samples for which the model returned the correct label as the top first or
|
||||||
// second output.
|
// second output.
|
||||||
repeated float topk_accuracy_percentages = 1;
|
repeated float topk_accuracies = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parameters that define how the Image Classification task is evaluated
|
||||||
|
// end-to-end.
|
||||||
|
//
|
||||||
|
// Next ID: 3
|
||||||
|
message ImageClassificationParams {
|
||||||
|
// Required.
|
||||||
|
// TfLite model should have 1 input & 1 output tensor.
|
||||||
|
// Input shape: {1, image_height, image_width, 3}
|
||||||
|
// Output shape: {1, num_total_labels}
|
||||||
|
optional TfliteInferenceParams inference_params = 1;
|
||||||
|
|
||||||
|
// Optional.
|
||||||
|
// If not set, accuracy evaluation is not performed.
|
||||||
|
optional TopkAccuracyEvalParams topk_accuracy_eval_params = 2;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Metrics from evaluation of the image classification task.
|
||||||
|
//
|
||||||
|
// Next ID: 5
|
||||||
|
message ImageClassificationMetrics {
|
||||||
|
optional LatencyMetrics pre_processing_latency = 1;
|
||||||
|
optional LatencyMetrics inference_latency = 2;
|
||||||
|
optional TfliteInferenceMetrics inference_metrics = 3;
|
||||||
|
// Not set if topk_accuracy_eval_params was not populated in
|
||||||
|
// ImageClassificationParams.
|
||||||
|
optional TopkAccuracyEvalMetrics topk_accuracy_metrics = 4;
|
||||||
}
|
}
|
||||||
|
@ -127,3 +127,19 @@ cc_test(
|
|||||||
"@com_google_googletest//:gtest_main",
|
"@com_google_googletest//:gtest_main",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "image_classification_stage",
|
||||||
|
srcs = ["image_classification_stage.cc"],
|
||||||
|
hdrs = ["image_classification_stage.h"],
|
||||||
|
copts = tflite_copts(),
|
||||||
|
deps = [
|
||||||
|
":image_preprocessing_stage",
|
||||||
|
":tflite_inference_stage",
|
||||||
|
":topk_accuracy_eval_stage",
|
||||||
|
"//tensorflow/core:tflite_portable_logging",
|
||||||
|
"//tensorflow/lite/tools/evaluation:evaluation_stage",
|
||||||
|
"//tensorflow/lite/tools/evaluation/proto:evaluation_config_cc_proto",
|
||||||
|
"//tensorflow/lite/tools/evaluation/proto:evaluation_stages_cc_proto",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
@ -0,0 +1,142 @@
|
|||||||
|
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
#include "tensorflow/lite/tools/evaluation/stages/image_classification_stage.h"
|
||||||
|
|
||||||
|
#include "tensorflow/core/platform/logging.h"
|
||||||
|
#include "tensorflow/lite/tools/evaluation/proto/evaluation_config.pb.h"
|
||||||
|
#include "tensorflow/lite/tools/evaluation/proto/evaluation_stages.pb.h"
|
||||||
|
|
||||||
|
namespace tflite {
|
||||||
|
namespace evaluation {
|
||||||
|
|
||||||
|
TfLiteStatus ImageClassificationStage::Init() {
|
||||||
|
// Ensure inference params are provided.
|
||||||
|
if (!config_.specification().has_image_classification_params()) {
|
||||||
|
LOG(ERROR) << "ImageClassificationParams not provided";
|
||||||
|
return kTfLiteError;
|
||||||
|
}
|
||||||
|
auto& params = config_.specification().image_classification_params();
|
||||||
|
if (!params.has_inference_params()) {
|
||||||
|
LOG(ERROR) << "Inference_params not provided";
|
||||||
|
return kTfLiteError;
|
||||||
|
}
|
||||||
|
|
||||||
|
// TfliteInferenceStage.
|
||||||
|
EvaluationStageConfig tflite_inference_config;
|
||||||
|
tflite_inference_config.set_name("tflite_inference");
|
||||||
|
*tflite_inference_config.mutable_specification()
|
||||||
|
->mutable_tflite_inference_params() = params.inference_params();
|
||||||
|
inference_stage_.reset(new TfliteInferenceStage(tflite_inference_config));
|
||||||
|
if (inference_stage_->Init() != kTfLiteOk) return kTfLiteError;
|
||||||
|
|
||||||
|
// Validate model inputs.
|
||||||
|
const TfLiteModelInfo* model_info = inference_stage_->GetModelInfo();
|
||||||
|
if (model_info->inputs.size() != 1 || model_info->outputs.size() != 1) {
|
||||||
|
LOG(ERROR) << "Model must have 1 input & 1 output";
|
||||||
|
return kTfLiteError;
|
||||||
|
}
|
||||||
|
TfLiteType input_type = model_info->inputs[0]->type;
|
||||||
|
auto* input_shape = model_info->inputs[0]->dims;
|
||||||
|
// Input should be of the shape {1, height, width, 3}
|
||||||
|
if (input_shape->size != 4 || input_shape->data[0] != 1 ||
|
||||||
|
input_shape->data[3] != 3) {
|
||||||
|
LOG(ERROR) << "Invalid input shape for model";
|
||||||
|
return kTfLiteError;
|
||||||
|
}
|
||||||
|
|
||||||
|
// ImagePreprocessingStage
|
||||||
|
EvaluationStageConfig preprocessing_config;
|
||||||
|
preprocessing_config.set_name("image_preprocessing");
|
||||||
|
auto* preprocess_params = preprocessing_config.mutable_specification()
|
||||||
|
->mutable_image_preprocessing_params();
|
||||||
|
preprocess_params->set_image_height(input_shape->data[1]);
|
||||||
|
preprocess_params->set_image_width(input_shape->data[2]);
|
||||||
|
preprocess_params->set_output_type(static_cast<int>(input_type));
|
||||||
|
preprocessing_stage_.reset(new ImagePreprocessingStage(preprocessing_config));
|
||||||
|
if (preprocessing_stage_->Init() != kTfLiteOk) return kTfLiteError;
|
||||||
|
|
||||||
|
// TopkAccuracyEvalStage.
|
||||||
|
if (params.has_topk_accuracy_eval_params()) {
|
||||||
|
EvaluationStageConfig topk_accuracy_eval_config;
|
||||||
|
topk_accuracy_eval_config.set_name("topk_accuracy");
|
||||||
|
*topk_accuracy_eval_config.mutable_specification()
|
||||||
|
->mutable_topk_accuracy_eval_params() =
|
||||||
|
params.topk_accuracy_eval_params();
|
||||||
|
if (!all_labels_) {
|
||||||
|
LOG(ERROR) << "all_labels not set for TopkAccuracyEvalStage";
|
||||||
|
return kTfLiteError;
|
||||||
|
}
|
||||||
|
accuracy_eval_stage_.reset(
|
||||||
|
new TopkAccuracyEvalStage(topk_accuracy_eval_config));
|
||||||
|
accuracy_eval_stage_->SetTaskInfo(*all_labels_, input_type,
|
||||||
|
model_info->outputs[0]->dims);
|
||||||
|
if (accuracy_eval_stage_->Init() != kTfLiteOk) return kTfLiteError;
|
||||||
|
}
|
||||||
|
|
||||||
|
return kTfLiteOk;
|
||||||
|
}
|
||||||
|
|
||||||
|
TfLiteStatus ImageClassificationStage::Run() {
|
||||||
|
if (image_path_.empty()) {
|
||||||
|
LOG(ERROR) << "Input image not set";
|
||||||
|
return kTfLiteError;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Preprocessing.
|
||||||
|
preprocessing_stage_->SetImagePath(&image_path_);
|
||||||
|
if (preprocessing_stage_->Run() != kTfLiteOk) return kTfLiteError;
|
||||||
|
// Inference.
|
||||||
|
std::vector<void*> data_ptrs = {};
|
||||||
|
data_ptrs.push_back(preprocessing_stage_->GetPreprocessedImageData());
|
||||||
|
inference_stage_->SetInputs(data_ptrs);
|
||||||
|
if (inference_stage_->Run() != kTfLiteOk) return kTfLiteError;
|
||||||
|
// Accuracy Eval.
|
||||||
|
if (accuracy_eval_stage_) {
|
||||||
|
if (ground_truth_label_.empty()) {
|
||||||
|
LOG(ERROR) << "Ground truth label not provided";
|
||||||
|
return kTfLiteError;
|
||||||
|
}
|
||||||
|
accuracy_eval_stage_->SetEvalInputs(inference_stage_->GetOutputs()->at(0),
|
||||||
|
&ground_truth_label_);
|
||||||
|
if (accuracy_eval_stage_->Run() != kTfLiteOk) return kTfLiteError;
|
||||||
|
}
|
||||||
|
|
||||||
|
return kTfLiteOk;
|
||||||
|
}
|
||||||
|
|
||||||
|
EvaluationStageMetrics ImageClassificationStage::LatestMetrics() {
|
||||||
|
EvaluationStageMetrics metrics;
|
||||||
|
auto* classification_metrics =
|
||||||
|
metrics.mutable_process_metrics()->mutable_image_classification_metrics();
|
||||||
|
|
||||||
|
*classification_metrics->mutable_pre_processing_latency() =
|
||||||
|
preprocessing_stage_->LatestMetrics().process_metrics().total_latency();
|
||||||
|
EvaluationStageMetrics inference_metrics = inference_stage_->LatestMetrics();
|
||||||
|
*classification_metrics->mutable_inference_latency() =
|
||||||
|
inference_metrics.process_metrics().total_latency();
|
||||||
|
*classification_metrics->mutable_inference_metrics() =
|
||||||
|
inference_metrics.process_metrics().tflite_inference_metrics();
|
||||||
|
if (accuracy_eval_stage_) {
|
||||||
|
*classification_metrics->mutable_topk_accuracy_metrics() =
|
||||||
|
accuracy_eval_stage_->LatestMetrics()
|
||||||
|
.process_metrics()
|
||||||
|
.topk_accuracy_metrics();
|
||||||
|
}
|
||||||
|
metrics.set_num_runs(inference_metrics.num_runs());
|
||||||
|
return metrics;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace evaluation
|
||||||
|
} // namespace tflite
|
@ -0,0 +1,73 @@
|
|||||||
|
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
#ifndef TENSORFLOW_LITE_TOOLS_EVALUATION_STAGES_IMAGE_CLASSIFICATION_STAGE_H_
|
||||||
|
#define TENSORFLOW_LITE_TOOLS_EVALUATION_STAGES_IMAGE_CLASSIFICATION_STAGE_H_
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "tensorflow/lite/tools/evaluation/evaluation_stage.h"
|
||||||
|
#include "tensorflow/lite/tools/evaluation/proto/evaluation_config.pb.h"
|
||||||
|
#include "tensorflow/lite/tools/evaluation/stages/image_preprocessing_stage.h"
|
||||||
|
#include "tensorflow/lite/tools/evaluation/stages/tflite_inference_stage.h"
|
||||||
|
#include "tensorflow/lite/tools/evaluation/stages/topk_accuracy_eval_stage.h"
|
||||||
|
|
||||||
|
namespace tflite {
|
||||||
|
namespace evaluation {
|
||||||
|
|
||||||
|
// An EvaluationStage to encapsulate the complete Image Classification task.
|
||||||
|
// Utilizes ImagePreprocessingStage, TfLiteInferenceStage &
|
||||||
|
// TopkAccuracyEvalStage for individual sub-tasks.
|
||||||
|
class ImageClassificationStage : public EvaluationStage {
|
||||||
|
public:
|
||||||
|
explicit ImageClassificationStage(const EvaluationStageConfig& config)
|
||||||
|
: EvaluationStage(config) {}
|
||||||
|
|
||||||
|
TfLiteStatus Init() override;
|
||||||
|
|
||||||
|
TfLiteStatus Run() override;
|
||||||
|
|
||||||
|
EvaluationStageMetrics LatestMetrics() override;
|
||||||
|
|
||||||
|
// Call before Init(), if topk_accuracy_eval_params is set in
|
||||||
|
// ImageClassificationParams. all_labels should contain the labels
|
||||||
|
// corresponding to model's output, in the same order. all_labels should
|
||||||
|
// outlive the call to Init().
|
||||||
|
void SetAllLabels(const std::vector<std::string>& all_labels) {
|
||||||
|
all_labels_ = &all_labels;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Call before Run().
|
||||||
|
// If accuracy eval is not being performed, ground_truth_label is ignored.
|
||||||
|
void SetInputs(const std::string& image_path,
|
||||||
|
const std::string& ground_truth_label) {
|
||||||
|
image_path_ = image_path;
|
||||||
|
ground_truth_label_ = ground_truth_label;
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
const std::vector<std::string>* all_labels_ = nullptr;
|
||||||
|
std::unique_ptr<ImagePreprocessingStage> preprocessing_stage_;
|
||||||
|
std::unique_ptr<TfliteInferenceStage> inference_stage_;
|
||||||
|
std::unique_ptr<TopkAccuracyEvalStage> accuracy_eval_stage_;
|
||||||
|
std::string image_path_;
|
||||||
|
std::string ground_truth_label_;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace evaluation
|
||||||
|
} // namespace tflite
|
||||||
|
|
||||||
|
#endif // TENSORFLOW_LITE_TOOLS_EVALUATION_STAGES_IMAGE_CLASSIFICATION_STAGE_H_
|
@ -121,8 +121,7 @@ EvaluationStageMetrics TopkAccuracyEvalStage::LatestMetrics() {
|
|||||||
auto* topk_metrics =
|
auto* topk_metrics =
|
||||||
metrics.mutable_process_metrics()->mutable_topk_accuracy_metrics();
|
metrics.mutable_process_metrics()->mutable_topk_accuracy_metrics();
|
||||||
for (const auto& count : accuracy_counts_) {
|
for (const auto& count : accuracy_counts_) {
|
||||||
topk_metrics->add_topk_accuracy_percentages(static_cast<float>(count) /
|
topk_metrics->add_topk_accuracies(static_cast<float>(count) / num_runs_);
|
||||||
num_runs_);
|
|
||||||
}
|
}
|
||||||
return metrics;
|
return metrics;
|
||||||
}
|
}
|
||||||
|
@ -217,9 +217,9 @@ TEST(TopkAccuracyEvalStage, FloatTest_CorrectLabelsAtLastIndices) {
|
|||||||
EXPECT_EQ(1, metrics.num_runs());
|
EXPECT_EQ(1, metrics.num_runs());
|
||||||
auto accuracy_metrics = metrics.process_metrics().topk_accuracy_metrics();
|
auto accuracy_metrics = metrics.process_metrics().topk_accuracy_metrics();
|
||||||
// Only top-5 count is 1.0, rest are 0.0
|
// Only top-5 count is 1.0, rest are 0.0
|
||||||
EXPECT_FLOAT_EQ(1.0, accuracy_metrics.topk_accuracy_percentages(4));
|
EXPECT_FLOAT_EQ(1.0, accuracy_metrics.topk_accuracies(4));
|
||||||
for (int i = 0; i < 4; ++i) {
|
for (int i = 0; i < 4; ++i) {
|
||||||
EXPECT_FLOAT_EQ(0.0, accuracy_metrics.topk_accuracy_percentages(i));
|
EXPECT_FLOAT_EQ(0.0, accuracy_metrics.topk_accuracies(i));
|
||||||
}
|
}
|
||||||
|
|
||||||
// The ground truth is index 1, but it is 4th highest based on model's output.
|
// The ground truth is index 1, but it is 4th highest based on model's output.
|
||||||
@ -231,10 +231,10 @@ TEST(TopkAccuracyEvalStage, FloatTest_CorrectLabelsAtLastIndices) {
|
|||||||
accuracy_metrics = metrics.process_metrics().topk_accuracy_metrics();
|
accuracy_metrics = metrics.process_metrics().topk_accuracy_metrics();
|
||||||
// 1/2 images had the currect output in top-4, 2/2 has currect output in
|
// 1/2 images had the currect output in top-4, 2/2 has currect output in
|
||||||
// top-5.
|
// top-5.
|
||||||
EXPECT_FLOAT_EQ(1.0, accuracy_metrics.topk_accuracy_percentages(4));
|
EXPECT_FLOAT_EQ(1.0, accuracy_metrics.topk_accuracies(4));
|
||||||
EXPECT_FLOAT_EQ(0.5, accuracy_metrics.topk_accuracy_percentages(3));
|
EXPECT_FLOAT_EQ(0.5, accuracy_metrics.topk_accuracies(3));
|
||||||
for (int i = 0; i < 3; ++i) {
|
for (int i = 0; i < 3; ++i) {
|
||||||
EXPECT_FLOAT_EQ(0.0, accuracy_metrics.topk_accuracy_percentages(i));
|
EXPECT_FLOAT_EQ(0.0, accuracy_metrics.topk_accuracies(i));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -261,7 +261,7 @@ class CorrectTopkAccuracyEvalTest : public ::testing::Test {
|
|||||||
EvaluationStageMetrics metrics = stage.LatestMetrics();
|
EvaluationStageMetrics metrics = stage.LatestMetrics();
|
||||||
EXPECT_EQ(0, metrics.num_runs());
|
EXPECT_EQ(0, metrics.num_runs());
|
||||||
auto accuracy_metrics = metrics.process_metrics().topk_accuracy_metrics();
|
auto accuracy_metrics = metrics.process_metrics().topk_accuracy_metrics();
|
||||||
EXPECT_EQ(0, accuracy_metrics.topk_accuracy_percentages_size());
|
EXPECT_EQ(0, accuracy_metrics.topk_accuracies_size());
|
||||||
|
|
||||||
T array[kNumCategories];
|
T array[kNumCategories];
|
||||||
|
|
||||||
@ -274,9 +274,8 @@ class CorrectTopkAccuracyEvalTest : public ::testing::Test {
|
|||||||
metrics = stage.LatestMetrics();
|
metrics = stage.LatestMetrics();
|
||||||
EXPECT_EQ(1, metrics.num_runs());
|
EXPECT_EQ(1, metrics.num_runs());
|
||||||
accuracy_metrics = metrics.process_metrics().topk_accuracy_metrics();
|
accuracy_metrics = metrics.process_metrics().topk_accuracy_metrics();
|
||||||
for (int i = 0; i < accuracy_metrics.topk_accuracy_percentages_size();
|
for (int i = 0; i < accuracy_metrics.topk_accuracies_size(); ++i) {
|
||||||
++i) {
|
EXPECT_FLOAT_EQ(1.0, accuracy_metrics.topk_accuracies(i));
|
||||||
EXPECT_FLOAT_EQ(1.0, accuracy_metrics.topk_accuracy_percentages(i));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Second image was also correctly identified as "1".
|
// Second image was also correctly identified as "1".
|
||||||
@ -288,9 +287,8 @@ class CorrectTopkAccuracyEvalTest : public ::testing::Test {
|
|||||||
metrics = stage.LatestMetrics();
|
metrics = stage.LatestMetrics();
|
||||||
EXPECT_EQ(2, metrics.num_runs());
|
EXPECT_EQ(2, metrics.num_runs());
|
||||||
accuracy_metrics = metrics.process_metrics().topk_accuracy_metrics();
|
accuracy_metrics = metrics.process_metrics().topk_accuracy_metrics();
|
||||||
for (int i = 0; i < accuracy_metrics.topk_accuracy_percentages_size();
|
for (int i = 0; i < accuracy_metrics.topk_accuracies_size(); ++i) {
|
||||||
++i) {
|
EXPECT_FLOAT_EQ(1.0, accuracy_metrics.topk_accuracies(i));
|
||||||
EXPECT_FLOAT_EQ(1.0, accuracy_metrics.topk_accuracy_percentages(i));
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
2
tensorflow/lite/tools/evaluation/testdata/labels.txt
vendored
Normal file
2
tensorflow/lite/tools/evaluation/testdata/labels.txt
vendored
Normal file
@ -0,0 +1,2 @@
|
|||||||
|
label1
|
||||||
|
label2
|
48
tensorflow/lite/tools/evaluation/utils.cc
Normal file
48
tensorflow/lite/tools/evaluation/utils.cc
Normal file
@ -0,0 +1,48 @@
|
|||||||
|
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/lite/tools/evaluation/utils.h"
|
||||||
|
|
||||||
|
#include <sys/stat.h>
|
||||||
|
|
||||||
|
#include <fstream>
|
||||||
|
#include <memory>
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
#include "tensorflow/core/platform/logging.h"
|
||||||
|
|
||||||
|
namespace tflite {
|
||||||
|
namespace evaluation {
|
||||||
|
|
||||||
|
bool ReadFileLines(const std::string& file_path,
|
||||||
|
std::vector<std::string>* lines_output) {
|
||||||
|
if (!lines_output) {
|
||||||
|
LOG(ERROR) << "lines_output is null";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
std::ifstream stream(file_path.c_str());
|
||||||
|
if (!stream) {
|
||||||
|
LOG(ERROR) << "Unable to open file: " << file_path;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
std::string line;
|
||||||
|
while (std::getline(stream, line)) {
|
||||||
|
lines_output->push_back(line);
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace evaluation
|
||||||
|
} // namespace tflite
|
29
tensorflow/lite/tools/evaluation/utils.h
Normal file
29
tensorflow/lite/tools/evaluation/utils.h
Normal file
@ -0,0 +1,29 @@
|
|||||||
|
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#ifndef TENSORFLOW_LITE_TOOLS_EVALUATION_UTILS_H_
|
||||||
|
#define TENSORFLOW_LITE_TOOLS_EVALUATION_UTILS_H_
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
namespace tflite {
|
||||||
|
namespace evaluation {
|
||||||
|
bool ReadFileLines(const std::string& file_path,
|
||||||
|
std::vector<std::string>* lines_output);
|
||||||
|
} // namespace evaluation
|
||||||
|
} // namespace tflite
|
||||||
|
|
||||||
|
#endif // TENSORFLOW_LITE_TOOLS_EVALUATION_UTILS_H_
|
49
tensorflow/lite/tools/evaluation/utils_test.cc
Normal file
49
tensorflow/lite/tools/evaluation/utils_test.cc
Normal file
@ -0,0 +1,49 @@
|
|||||||
|
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
#include "tensorflow/lite/tools/evaluation/utils.h"
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include <gtest/gtest.h>
|
||||||
|
|
||||||
|
namespace tflite {
|
||||||
|
namespace evaluation {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
constexpr char kFilePath[] =
|
||||||
|
"tensorflow/lite/tools/evaluation/testdata/labels.txt";
|
||||||
|
|
||||||
|
TEST(UtilsTest, ReadFileErrors) {
|
||||||
|
std::string correct_path(kFilePath);
|
||||||
|
std::string wrong_path("xyz.txt");
|
||||||
|
std::vector<std::string> lines;
|
||||||
|
EXPECT_FALSE(ReadFileLines(correct_path, nullptr));
|
||||||
|
EXPECT_FALSE(ReadFileLines(wrong_path, &lines));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(UtilsTest, ReadFileCorrectly) {
|
||||||
|
std::string file_path(kFilePath);
|
||||||
|
std::vector<std::string> lines;
|
||||||
|
EXPECT_TRUE(ReadFileLines(file_path, &lines));
|
||||||
|
|
||||||
|
EXPECT_EQ(lines.size(), 2);
|
||||||
|
EXPECT_EQ(lines[0], "label1");
|
||||||
|
EXPECT_EQ(lines[1], "label2");
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
} // namespace evaluation
|
||||||
|
} // namespace tflite
|
Loading…
Reference in New Issue
Block a user