Internal Change
PiperOrigin-RevId: 245962899
This commit is contained in:
parent
aa7d77dedc
commit
efcb8bba81
@ -42,14 +42,6 @@ constexpr char kGroundTruthLabelsFlag[] = "ground_truth_labels";
|
|||||||
constexpr char kBlacklistFilePathFlag[] = "blacklist_file_path";
|
constexpr char kBlacklistFilePathFlag[] = "blacklist_file_path";
|
||||||
constexpr char kModelFileFlag[] = "model_file";
|
constexpr char kModelFileFlag[] = "model_file";
|
||||||
|
|
||||||
std::string StripTrailingSlashes(const std::string& path) {
|
|
||||||
int end = path.size();
|
|
||||||
while (end > 0 && path[end - 1] == '/') {
|
|
||||||
end--;
|
|
||||||
}
|
|
||||||
return path.substr(0, end);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
std::vector<T> GetFirstN(const std::vector<T>& v, int n) {
|
std::vector<T> GetFirstN(const std::vector<T>& v, int n) {
|
||||||
if (n >= v.size()) return v;
|
if (n >= v.size()) return v;
|
||||||
@ -232,35 +224,13 @@ TfLiteStatus FilterBlackListedImages(const std::string& blacklist_file_path,
|
|||||||
return kTfLiteOk;
|
return kTfLiteOk;
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO(b/130823599): Move to tools/evaluation/utils.
|
|
||||||
TfLiteStatus GetSortedFileNames(const std::string dir_path,
|
|
||||||
std::vector<std::string>* result) {
|
|
||||||
DIR* dir;
|
|
||||||
struct dirent* ent;
|
|
||||||
if (result == nullptr) {
|
|
||||||
LOG(ERROR) << "result cannot be nullptr";
|
|
||||||
return kTfLiteError;
|
|
||||||
}
|
|
||||||
if ((dir = opendir(dir_path.c_str())) != nullptr) {
|
|
||||||
while ((ent = readdir(dir)) != nullptr) {
|
|
||||||
std::string filename(std::string(ent->d_name));
|
|
||||||
if (filename.size() <= 2) continue;
|
|
||||||
result->emplace_back(dir_path + "/" + filename);
|
|
||||||
}
|
|
||||||
closedir(dir);
|
|
||||||
} else {
|
|
||||||
LOG(ERROR) << "Could not open dir: " << dir_path;
|
|
||||||
return kTfLiteError;
|
|
||||||
}
|
|
||||||
std::sort(result->begin(), result->end());
|
|
||||||
return kTfLiteOk;
|
|
||||||
}
|
|
||||||
|
|
||||||
TfLiteStatus ImagenetModelEvaluator::EvaluateModel() const {
|
TfLiteStatus ImagenetModelEvaluator::EvaluateModel() const {
|
||||||
const std::string data_path =
|
const std::string data_path = tflite::evaluation::StripTrailingSlashes(
|
||||||
StripTrailingSlashes(params_.ground_truth_images_path) + "/";
|
params_.ground_truth_images_path) +
|
||||||
|
"/";
|
||||||
std::vector<std::string> image_files;
|
std::vector<std::string> image_files;
|
||||||
TF_LITE_ENSURE_STATUS(GetSortedFileNames(data_path, &image_files));
|
TF_LITE_ENSURE_STATUS(
|
||||||
|
tflite::evaluation::GetSortedFileNames(data_path, &image_files));
|
||||||
std::vector<string> ground_truth_image_labels;
|
std::vector<string> ground_truth_image_labels;
|
||||||
if (!tflite::evaluation::ReadFileLines(params_.ground_truth_labels_path,
|
if (!tflite::evaluation::ReadFileLines(params_.ground_truth_labels_path,
|
||||||
&ground_truth_image_labels))
|
&ground_truth_image_labels))
|
||||||
|
@ -40,6 +40,7 @@ cc_library(
|
|||||||
copts = tflite_copts(),
|
copts = tflite_copts(),
|
||||||
deps = [
|
deps = [
|
||||||
"//tensorflow/core:tflite_portable_logging",
|
"//tensorflow/core:tflite_portable_logging",
|
||||||
|
"//tensorflow/lite:context",
|
||||||
"//tensorflow/lite:framework",
|
"//tensorflow/lite:framework",
|
||||||
"//tensorflow/lite/delegates/nnapi:nnapi_delegate",
|
"//tensorflow/lite/delegates/nnapi:nnapi_delegate",
|
||||||
] + select({
|
] + select({
|
||||||
@ -53,11 +54,15 @@ cc_library(
|
|||||||
cc_test(
|
cc_test(
|
||||||
name = "utils_test",
|
name = "utils_test",
|
||||||
srcs = ["utils_test.cc"],
|
srcs = ["utils_test.cc"],
|
||||||
data = ["testdata/labels.txt"],
|
data = [
|
||||||
|
"testdata/empty.txt",
|
||||||
|
"testdata/labels.txt",
|
||||||
|
],
|
||||||
linkopts = tflite_linkopts(),
|
linkopts = tflite_linkopts(),
|
||||||
linkstatic = 1,
|
linkstatic = 1,
|
||||||
deps = [
|
deps = [
|
||||||
":utils",
|
":utils",
|
||||||
|
"//tensorflow/lite:context",
|
||||||
"@com_google_googletest//:gtest_main",
|
"@com_google_googletest//:gtest_main",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
@ -30,6 +30,11 @@ cc_proto_library(
|
|||||||
deps = ["evaluation_stages_proto"],
|
deps = ["evaluation_stages_proto"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
java_proto_library(
|
||||||
|
name = "evaluation_stages_java_proto",
|
||||||
|
deps = ["evaluation_stages_proto"],
|
||||||
|
)
|
||||||
|
|
||||||
proto_library(
|
proto_library(
|
||||||
name = "evaluation_config_proto",
|
name = "evaluation_config_proto",
|
||||||
srcs = [
|
srcs = [
|
||||||
@ -43,3 +48,8 @@ cc_proto_library(
|
|||||||
name = "evaluation_config_cc_proto",
|
name = "evaluation_config_cc_proto",
|
||||||
deps = ["evaluation_config_proto"],
|
deps = ["evaluation_config_proto"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
java_proto_library(
|
||||||
|
name = "evaluation_config_java_proto",
|
||||||
|
deps = ["evaluation_config_proto"],
|
||||||
|
)
|
||||||
|
@ -19,6 +19,10 @@ package tflite.evaluation;
|
|||||||
|
|
||||||
import "tensorflow/lite/tools/evaluation/proto/evaluation_stages.proto";
|
import "tensorflow/lite/tools/evaluation/proto/evaluation_stages.proto";
|
||||||
|
|
||||||
|
option cc_enable_arenas = true;
|
||||||
|
option java_multiple_files = true;
|
||||||
|
option java_package = "tflite.evaluation";
|
||||||
|
|
||||||
// Contains parameters that define how an EvaluationStage will be executed.
|
// Contains parameters that define how an EvaluationStage will be executed.
|
||||||
// This would typically be validated only once during initialization, so should
|
// This would typically be validated only once during initialization, so should
|
||||||
// not contain any variables that change with each run.
|
// not contain any variables that change with each run.
|
||||||
|
@ -17,6 +17,10 @@ syntax = "proto2";
|
|||||||
|
|
||||||
package tflite.evaluation;
|
package tflite.evaluation;
|
||||||
|
|
||||||
|
option cc_enable_arenas = true;
|
||||||
|
option java_multiple_files = true;
|
||||||
|
option java_package = "tflite.evaluation";
|
||||||
|
|
||||||
// Defines the functionality executed by an EvaluationStage.
|
// Defines the functionality executed by an EvaluationStage.
|
||||||
//
|
//
|
||||||
// Next ID: 5
|
// Next ID: 5
|
||||||
|
0
tensorflow/lite/tools/evaluation/testdata/empty.txt
vendored
Normal file
0
tensorflow/lite/tools/evaluation/testdata/empty.txt
vendored
Normal file
@ -15,8 +15,10 @@ limitations under the License.
|
|||||||
|
|
||||||
#include "tensorflow/lite/tools/evaluation/utils.h"
|
#include "tensorflow/lite/tools/evaluation/utils.h"
|
||||||
|
|
||||||
|
#include <dirent.h>
|
||||||
#include <sys/stat.h>
|
#include <sys/stat.h>
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
#include <fstream>
|
#include <fstream>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <string>
|
#include <string>
|
||||||
@ -31,6 +33,14 @@ limitations under the License.
|
|||||||
namespace tflite {
|
namespace tflite {
|
||||||
namespace evaluation {
|
namespace evaluation {
|
||||||
|
|
||||||
|
std::string StripTrailingSlashes(const std::string& path) {
|
||||||
|
int end = path.size();
|
||||||
|
while (end > 0 && path[end - 1] == '/') {
|
||||||
|
end--;
|
||||||
|
}
|
||||||
|
return path.substr(0, end);
|
||||||
|
}
|
||||||
|
|
||||||
bool ReadFileLines(const std::string& file_path,
|
bool ReadFileLines(const std::string& file_path,
|
||||||
std::vector<std::string>* lines_output) {
|
std::vector<std::string>* lines_output) {
|
||||||
if (!lines_output) {
|
if (!lines_output) {
|
||||||
@ -49,6 +59,31 @@ bool ReadFileLines(const std::string& file_path,
|
|||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TfLiteStatus GetSortedFileNames(const std::string& directory,
|
||||||
|
std::vector<std::string>* result) {
|
||||||
|
DIR* dir;
|
||||||
|
struct dirent* ent;
|
||||||
|
if (result == nullptr) {
|
||||||
|
LOG(ERROR) << "result cannot be nullptr";
|
||||||
|
return kTfLiteError;
|
||||||
|
}
|
||||||
|
result->clear();
|
||||||
|
std::string dir_path = StripTrailingSlashes(directory);
|
||||||
|
if ((dir = opendir(dir_path.c_str())) != nullptr) {
|
||||||
|
while ((ent = readdir(dir)) != nullptr) {
|
||||||
|
std::string filename(std::string(ent->d_name));
|
||||||
|
if (filename.size() <= 2) continue;
|
||||||
|
result->emplace_back(dir_path + "/" + filename);
|
||||||
|
}
|
||||||
|
closedir(dir);
|
||||||
|
} else {
|
||||||
|
LOG(ERROR) << "Could not open dir: " << dir_path;
|
||||||
|
return kTfLiteError;
|
||||||
|
}
|
||||||
|
std::sort(result->begin(), result->end());
|
||||||
|
return kTfLiteOk;
|
||||||
|
}
|
||||||
|
|
||||||
Interpreter::TfLiteDelegatePtr CreateNNAPIDelegate() {
|
Interpreter::TfLiteDelegatePtr CreateNNAPIDelegate() {
|
||||||
#if defined(__ANDROID__)
|
#if defined(__ANDROID__)
|
||||||
return Interpreter::TfLiteDelegatePtr(
|
return Interpreter::TfLiteDelegatePtr(
|
||||||
|
@ -19,13 +19,19 @@ limitations under the License.
|
|||||||
#include <string>
|
#include <string>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
#include "tensorflow/lite/context.h"
|
||||||
#include "tensorflow/lite/model.h"
|
#include "tensorflow/lite/model.h"
|
||||||
|
|
||||||
namespace tflite {
|
namespace tflite {
|
||||||
namespace evaluation {
|
namespace evaluation {
|
||||||
|
std::string StripTrailingSlashes(const std::string& path);
|
||||||
|
|
||||||
bool ReadFileLines(const std::string& file_path,
|
bool ReadFileLines(const std::string& file_path,
|
||||||
std::vector<std::string>* lines_output);
|
std::vector<std::string>* lines_output);
|
||||||
|
|
||||||
|
TfLiteStatus GetSortedFileNames(const std::string& directory,
|
||||||
|
std::vector<std::string>* result);
|
||||||
|
|
||||||
Interpreter::TfLiteDelegatePtr CreateNNAPIDelegate();
|
Interpreter::TfLiteDelegatePtr CreateNNAPIDelegate();
|
||||||
|
|
||||||
Interpreter::TfLiteDelegatePtr CreateGPUDelegate(FlatBufferModel* model);
|
Interpreter::TfLiteDelegatePtr CreateGPUDelegate(FlatBufferModel* model);
|
||||||
|
@ -18,16 +18,32 @@ limitations under the License.
|
|||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include <gtest/gtest.h>
|
#include <gtest/gtest.h>
|
||||||
|
#include "tensorflow/lite/context.h"
|
||||||
|
|
||||||
namespace tflite {
|
namespace tflite {
|
||||||
namespace evaluation {
|
namespace evaluation {
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
constexpr char kFilePath[] =
|
constexpr char kLabelsPath[] =
|
||||||
"tensorflow/lite/tools/evaluation/testdata/labels.txt";
|
"tensorflow/lite/tools/evaluation/testdata/labels.txt";
|
||||||
|
constexpr char kDirPath[] =
|
||||||
|
"tensorflow/lite/tools/evaluation/testdata";
|
||||||
|
constexpr char kEmptyFilePath[] =
|
||||||
|
"tensorflow/lite/tools/evaluation/testdata/empty.txt";
|
||||||
|
|
||||||
|
TEST(UtilsTest, StripTrailingSlashesTest) {
|
||||||
|
std::string path = "/usr/local/folder/";
|
||||||
|
EXPECT_EQ(StripTrailingSlashes(path), "/usr/local/folder");
|
||||||
|
|
||||||
|
path = "/usr/local/folder";
|
||||||
|
EXPECT_EQ(StripTrailingSlashes(path), path);
|
||||||
|
|
||||||
|
path = "folder";
|
||||||
|
EXPECT_EQ(StripTrailingSlashes(path), path);
|
||||||
|
}
|
||||||
|
|
||||||
TEST(UtilsTest, ReadFileErrors) {
|
TEST(UtilsTest, ReadFileErrors) {
|
||||||
std::string correct_path(kFilePath);
|
std::string correct_path(kLabelsPath);
|
||||||
std::string wrong_path("xyz.txt");
|
std::string wrong_path("xyz.txt");
|
||||||
std::vector<std::string> lines;
|
std::vector<std::string> lines;
|
||||||
EXPECT_FALSE(ReadFileLines(correct_path, nullptr));
|
EXPECT_FALSE(ReadFileLines(correct_path, nullptr));
|
||||||
@ -35,7 +51,7 @@ TEST(UtilsTest, ReadFileErrors) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
TEST(UtilsTest, ReadFileCorrectly) {
|
TEST(UtilsTest, ReadFileCorrectly) {
|
||||||
std::string file_path(kFilePath);
|
std::string file_path(kLabelsPath);
|
||||||
std::vector<std::string> lines;
|
std::vector<std::string> lines;
|
||||||
EXPECT_TRUE(ReadFileLines(file_path, &lines));
|
EXPECT_TRUE(ReadFileLines(file_path, &lines));
|
||||||
|
|
||||||
@ -44,6 +60,17 @@ TEST(UtilsTest, ReadFileCorrectly) {
|
|||||||
EXPECT_EQ(lines[1], "label2");
|
EXPECT_EQ(lines[1], "label2");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(UtilsTest, SortedFilenamesTest) {
|
||||||
|
std::vector<std::string> files;
|
||||||
|
EXPECT_EQ(GetSortedFileNames(kDirPath, &files), kTfLiteOk);
|
||||||
|
|
||||||
|
EXPECT_EQ(files.size(), 2);
|
||||||
|
EXPECT_EQ(files[0], kEmptyFilePath);
|
||||||
|
EXPECT_EQ(files[1], kLabelsPath);
|
||||||
|
|
||||||
|
EXPECT_EQ(GetSortedFileNames("wrong_path", &files), kTfLiteError);
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace evaluation
|
} // namespace evaluation
|
||||||
} // namespace tflite
|
} // namespace tflite
|
||||||
|
Loading…
Reference in New Issue
Block a user