Internal Change
PiperOrigin-RevId: 245962899
This commit is contained in:
parent
aa7d77dedc
commit
efcb8bba81
@ -42,14 +42,6 @@ constexpr char kGroundTruthLabelsFlag[] = "ground_truth_labels";
|
||||
constexpr char kBlacklistFilePathFlag[] = "blacklist_file_path";
|
||||
constexpr char kModelFileFlag[] = "model_file";
|
||||
|
||||
std::string StripTrailingSlashes(const std::string& path) {
|
||||
int end = path.size();
|
||||
while (end > 0 && path[end - 1] == '/') {
|
||||
end--;
|
||||
}
|
||||
return path.substr(0, end);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
std::vector<T> GetFirstN(const std::vector<T>& v, int n) {
|
||||
if (n >= v.size()) return v;
|
||||
@ -232,35 +224,13 @@ TfLiteStatus FilterBlackListedImages(const std::string& blacklist_file_path,
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
// TODO(b/130823599): Move to tools/evaluation/utils.
|
||||
TfLiteStatus GetSortedFileNames(const std::string dir_path,
|
||||
std::vector<std::string>* result) {
|
||||
DIR* dir;
|
||||
struct dirent* ent;
|
||||
if (result == nullptr) {
|
||||
LOG(ERROR) << "result cannot be nullptr";
|
||||
return kTfLiteError;
|
||||
}
|
||||
if ((dir = opendir(dir_path.c_str())) != nullptr) {
|
||||
while ((ent = readdir(dir)) != nullptr) {
|
||||
std::string filename(std::string(ent->d_name));
|
||||
if (filename.size() <= 2) continue;
|
||||
result->emplace_back(dir_path + "/" + filename);
|
||||
}
|
||||
closedir(dir);
|
||||
} else {
|
||||
LOG(ERROR) << "Could not open dir: " << dir_path;
|
||||
return kTfLiteError;
|
||||
}
|
||||
std::sort(result->begin(), result->end());
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
TfLiteStatus ImagenetModelEvaluator::EvaluateModel() const {
|
||||
const std::string data_path =
|
||||
StripTrailingSlashes(params_.ground_truth_images_path) + "/";
|
||||
const std::string data_path = tflite::evaluation::StripTrailingSlashes(
|
||||
params_.ground_truth_images_path) +
|
||||
"/";
|
||||
std::vector<std::string> image_files;
|
||||
TF_LITE_ENSURE_STATUS(GetSortedFileNames(data_path, &image_files));
|
||||
TF_LITE_ENSURE_STATUS(
|
||||
tflite::evaluation::GetSortedFileNames(data_path, &image_files));
|
||||
std::vector<string> ground_truth_image_labels;
|
||||
if (!tflite::evaluation::ReadFileLines(params_.ground_truth_labels_path,
|
||||
&ground_truth_image_labels))
|
||||
|
@ -40,6 +40,7 @@ cc_library(
|
||||
copts = tflite_copts(),
|
||||
deps = [
|
||||
"//tensorflow/core:tflite_portable_logging",
|
||||
"//tensorflow/lite:context",
|
||||
"//tensorflow/lite:framework",
|
||||
"//tensorflow/lite/delegates/nnapi:nnapi_delegate",
|
||||
] + select({
|
||||
@ -53,11 +54,15 @@ cc_library(
|
||||
cc_test(
|
||||
name = "utils_test",
|
||||
srcs = ["utils_test.cc"],
|
||||
data = ["testdata/labels.txt"],
|
||||
data = [
|
||||
"testdata/empty.txt",
|
||||
"testdata/labels.txt",
|
||||
],
|
||||
linkopts = tflite_linkopts(),
|
||||
linkstatic = 1,
|
||||
deps = [
|
||||
":utils",
|
||||
"//tensorflow/lite:context",
|
||||
"@com_google_googletest//:gtest_main",
|
||||
],
|
||||
)
|
||||
|
@ -30,6 +30,11 @@ cc_proto_library(
|
||||
deps = ["evaluation_stages_proto"],
|
||||
)
|
||||
|
||||
java_proto_library(
|
||||
name = "evaluation_stages_java_proto",
|
||||
deps = ["evaluation_stages_proto"],
|
||||
)
|
||||
|
||||
proto_library(
|
||||
name = "evaluation_config_proto",
|
||||
srcs = [
|
||||
@ -43,3 +48,8 @@ cc_proto_library(
|
||||
name = "evaluation_config_cc_proto",
|
||||
deps = ["evaluation_config_proto"],
|
||||
)
|
||||
|
||||
java_proto_library(
|
||||
name = "evaluation_config_java_proto",
|
||||
deps = ["evaluation_config_proto"],
|
||||
)
|
||||
|
@ -19,6 +19,10 @@ package tflite.evaluation;
|
||||
|
||||
import "tensorflow/lite/tools/evaluation/proto/evaluation_stages.proto";
|
||||
|
||||
option cc_enable_arenas = true;
|
||||
option java_multiple_files = true;
|
||||
option java_package = "tflite.evaluation";
|
||||
|
||||
// Contains parameters that define how an EvaluationStage will be executed.
|
||||
// This would typically be validated only once during initialization, so should
|
||||
// not contain any variables that change with each run.
|
||||
|
@ -17,6 +17,10 @@ syntax = "proto2";
|
||||
|
||||
package tflite.evaluation;
|
||||
|
||||
option cc_enable_arenas = true;
|
||||
option java_multiple_files = true;
|
||||
option java_package = "tflite.evaluation";
|
||||
|
||||
// Defines the functionality executed by an EvaluationStage.
|
||||
//
|
||||
// Next ID: 5
|
||||
|
0
tensorflow/lite/tools/evaluation/testdata/empty.txt
vendored
Normal file
0
tensorflow/lite/tools/evaluation/testdata/empty.txt
vendored
Normal file
@ -15,8 +15,10 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/lite/tools/evaluation/utils.h"
|
||||
|
||||
#include <dirent.h>
|
||||
#include <sys/stat.h>
|
||||
|
||||
#include <algorithm>
|
||||
#include <fstream>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
@ -31,6 +33,14 @@ limitations under the License.
|
||||
namespace tflite {
|
||||
namespace evaluation {
|
||||
|
||||
std::string StripTrailingSlashes(const std::string& path) {
|
||||
int end = path.size();
|
||||
while (end > 0 && path[end - 1] == '/') {
|
||||
end--;
|
||||
}
|
||||
return path.substr(0, end);
|
||||
}
|
||||
|
||||
bool ReadFileLines(const std::string& file_path,
|
||||
std::vector<std::string>* lines_output) {
|
||||
if (!lines_output) {
|
||||
@ -49,6 +59,31 @@ bool ReadFileLines(const std::string& file_path,
|
||||
return true;
|
||||
}
|
||||
|
||||
TfLiteStatus GetSortedFileNames(const std::string& directory,
|
||||
std::vector<std::string>* result) {
|
||||
DIR* dir;
|
||||
struct dirent* ent;
|
||||
if (result == nullptr) {
|
||||
LOG(ERROR) << "result cannot be nullptr";
|
||||
return kTfLiteError;
|
||||
}
|
||||
result->clear();
|
||||
std::string dir_path = StripTrailingSlashes(directory);
|
||||
if ((dir = opendir(dir_path.c_str())) != nullptr) {
|
||||
while ((ent = readdir(dir)) != nullptr) {
|
||||
std::string filename(std::string(ent->d_name));
|
||||
if (filename.size() <= 2) continue;
|
||||
result->emplace_back(dir_path + "/" + filename);
|
||||
}
|
||||
closedir(dir);
|
||||
} else {
|
||||
LOG(ERROR) << "Could not open dir: " << dir_path;
|
||||
return kTfLiteError;
|
||||
}
|
||||
std::sort(result->begin(), result->end());
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
Interpreter::TfLiteDelegatePtr CreateNNAPIDelegate() {
|
||||
#if defined(__ANDROID__)
|
||||
return Interpreter::TfLiteDelegatePtr(
|
||||
|
@ -19,13 +19,19 @@ limitations under the License.
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/lite/context.h"
|
||||
#include "tensorflow/lite/model.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace evaluation {
|
||||
std::string StripTrailingSlashes(const std::string& path);
|
||||
|
||||
bool ReadFileLines(const std::string& file_path,
|
||||
std::vector<std::string>* lines_output);
|
||||
|
||||
TfLiteStatus GetSortedFileNames(const std::string& directory,
|
||||
std::vector<std::string>* result);
|
||||
|
||||
Interpreter::TfLiteDelegatePtr CreateNNAPIDelegate();
|
||||
|
||||
Interpreter::TfLiteDelegatePtr CreateGPUDelegate(FlatBufferModel* model);
|
||||
|
@ -18,16 +18,32 @@ limitations under the License.
|
||||
#include <vector>
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include "tensorflow/lite/context.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace evaluation {
|
||||
namespace {
|
||||
|
||||
constexpr char kFilePath[] =
|
||||
constexpr char kLabelsPath[] =
|
||||
"tensorflow/lite/tools/evaluation/testdata/labels.txt";
|
||||
constexpr char kDirPath[] =
|
||||
"tensorflow/lite/tools/evaluation/testdata";
|
||||
constexpr char kEmptyFilePath[] =
|
||||
"tensorflow/lite/tools/evaluation/testdata/empty.txt";
|
||||
|
||||
TEST(UtilsTest, StripTrailingSlashesTest) {
|
||||
std::string path = "/usr/local/folder/";
|
||||
EXPECT_EQ(StripTrailingSlashes(path), "/usr/local/folder");
|
||||
|
||||
path = "/usr/local/folder";
|
||||
EXPECT_EQ(StripTrailingSlashes(path), path);
|
||||
|
||||
path = "folder";
|
||||
EXPECT_EQ(StripTrailingSlashes(path), path);
|
||||
}
|
||||
|
||||
TEST(UtilsTest, ReadFileErrors) {
|
||||
std::string correct_path(kFilePath);
|
||||
std::string correct_path(kLabelsPath);
|
||||
std::string wrong_path("xyz.txt");
|
||||
std::vector<std::string> lines;
|
||||
EXPECT_FALSE(ReadFileLines(correct_path, nullptr));
|
||||
@ -35,7 +51,7 @@ TEST(UtilsTest, ReadFileErrors) {
|
||||
}
|
||||
|
||||
TEST(UtilsTest, ReadFileCorrectly) {
|
||||
std::string file_path(kFilePath);
|
||||
std::string file_path(kLabelsPath);
|
||||
std::vector<std::string> lines;
|
||||
EXPECT_TRUE(ReadFileLines(file_path, &lines));
|
||||
|
||||
@ -44,6 +60,17 @@ TEST(UtilsTest, ReadFileCorrectly) {
|
||||
EXPECT_EQ(lines[1], "label2");
|
||||
}
|
||||
|
||||
TEST(UtilsTest, SortedFilenamesTest) {
|
||||
std::vector<std::string> files;
|
||||
EXPECT_EQ(GetSortedFileNames(kDirPath, &files), kTfLiteOk);
|
||||
|
||||
EXPECT_EQ(files.size(), 2);
|
||||
EXPECT_EQ(files[0], kEmptyFilePath);
|
||||
EXPECT_EQ(files[1], kLabelsPath);
|
||||
|
||||
EXPECT_EQ(GetSortedFileNames("wrong_path", &files), kTfLiteError);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace evaluation
|
||||
} // namespace tflite
|
||||
|
Loading…
Reference in New Issue
Block a user