Internal Change

PiperOrigin-RevId: 245962899
This commit is contained in:
Sachin Joglekar 2019-04-30 09:31:24 -07:00 committed by TensorFlower Gardener
parent aa7d77dedc
commit efcb8bba81
9 changed files with 100 additions and 39 deletions

View File

@ -42,14 +42,6 @@ constexpr char kGroundTruthLabelsFlag[] = "ground_truth_labels";
constexpr char kBlacklistFilePathFlag[] = "blacklist_file_path";
constexpr char kModelFileFlag[] = "model_file";
std::string StripTrailingSlashes(const std::string& path) {
int end = path.size();
while (end > 0 && path[end - 1] == '/') {
end--;
}
return path.substr(0, end);
}
template <typename T>
std::vector<T> GetFirstN(const std::vector<T>& v, int n) {
if (n >= v.size()) return v;
@ -232,35 +224,13 @@ TfLiteStatus FilterBlackListedImages(const std::string& blacklist_file_path,
return kTfLiteOk;
}
// TODO(b/130823599): Move to tools/evaluation/utils.
TfLiteStatus GetSortedFileNames(const std::string dir_path,
std::vector<std::string>* result) {
DIR* dir;
struct dirent* ent;
if (result == nullptr) {
LOG(ERROR) << "result cannot be nullptr";
return kTfLiteError;
}
if ((dir = opendir(dir_path.c_str())) != nullptr) {
while ((ent = readdir(dir)) != nullptr) {
std::string filename(std::string(ent->d_name));
if (filename.size() <= 2) continue;
result->emplace_back(dir_path + "/" + filename);
}
closedir(dir);
} else {
LOG(ERROR) << "Could not open dir: " << dir_path;
return kTfLiteError;
}
std::sort(result->begin(), result->end());
return kTfLiteOk;
}
TfLiteStatus ImagenetModelEvaluator::EvaluateModel() const {
const std::string data_path =
StripTrailingSlashes(params_.ground_truth_images_path) + "/";
const std::string data_path = tflite::evaluation::StripTrailingSlashes(
params_.ground_truth_images_path) +
"/";
std::vector<std::string> image_files;
TF_LITE_ENSURE_STATUS(GetSortedFileNames(data_path, &image_files));
TF_LITE_ENSURE_STATUS(
tflite::evaluation::GetSortedFileNames(data_path, &image_files));
std::vector<string> ground_truth_image_labels;
if (!tflite::evaluation::ReadFileLines(params_.ground_truth_labels_path,
&ground_truth_image_labels))

View File

@ -40,6 +40,7 @@ cc_library(
copts = tflite_copts(),
deps = [
"//tensorflow/core:tflite_portable_logging",
"//tensorflow/lite:context",
"//tensorflow/lite:framework",
"//tensorflow/lite/delegates/nnapi:nnapi_delegate",
] + select({
@ -53,11 +54,15 @@ cc_library(
cc_test(
name = "utils_test",
srcs = ["utils_test.cc"],
data = ["testdata/labels.txt"],
data = [
"testdata/empty.txt",
"testdata/labels.txt",
],
linkopts = tflite_linkopts(),
linkstatic = 1,
deps = [
":utils",
"//tensorflow/lite:context",
"@com_google_googletest//:gtest_main",
],
)

View File

@ -30,6 +30,11 @@ cc_proto_library(
deps = ["evaluation_stages_proto"],
)
java_proto_library(
name = "evaluation_stages_java_proto",
deps = ["evaluation_stages_proto"],
)
proto_library(
name = "evaluation_config_proto",
srcs = [
@ -43,3 +48,8 @@ cc_proto_library(
name = "evaluation_config_cc_proto",
deps = ["evaluation_config_proto"],
)
java_proto_library(
name = "evaluation_config_java_proto",
deps = ["evaluation_config_proto"],
)

View File

@ -19,6 +19,10 @@ package tflite.evaluation;
import "tensorflow/lite/tools/evaluation/proto/evaluation_stages.proto";
option cc_enable_arenas = true;
option java_multiple_files = true;
option java_package = "tflite.evaluation";
// Contains parameters that define how an EvaluationStage will be executed.
// This would typically be validated only once during initialization, so should
// not contain any variables that change with each run.

View File

@ -17,6 +17,10 @@ syntax = "proto2";
package tflite.evaluation;
option cc_enable_arenas = true;
option java_multiple_files = true;
option java_package = "tflite.evaluation";
// Defines the functionality executed by an EvaluationStage.
//
// Next ID: 5

View File

View File

@ -15,8 +15,10 @@ limitations under the License.
#include "tensorflow/lite/tools/evaluation/utils.h"
#include <dirent.h>
#include <sys/stat.h>
#include <algorithm>
#include <fstream>
#include <memory>
#include <string>
@ -31,6 +33,14 @@ limitations under the License.
namespace tflite {
namespace evaluation {
std::string StripTrailingSlashes(const std::string& path) {
int end = path.size();
while (end > 0 && path[end - 1] == '/') {
end--;
}
return path.substr(0, end);
}
bool ReadFileLines(const std::string& file_path,
std::vector<std::string>* lines_output) {
if (!lines_output) {
@ -49,6 +59,31 @@ bool ReadFileLines(const std::string& file_path,
return true;
}
TfLiteStatus GetSortedFileNames(const std::string& directory,
std::vector<std::string>* result) {
DIR* dir;
struct dirent* ent;
if (result == nullptr) {
LOG(ERROR) << "result cannot be nullptr";
return kTfLiteError;
}
result->clear();
std::string dir_path = StripTrailingSlashes(directory);
if ((dir = opendir(dir_path.c_str())) != nullptr) {
while ((ent = readdir(dir)) != nullptr) {
std::string filename(std::string(ent->d_name));
if (filename.size() <= 2) continue;
result->emplace_back(dir_path + "/" + filename);
}
closedir(dir);
} else {
LOG(ERROR) << "Could not open dir: " << dir_path;
return kTfLiteError;
}
std::sort(result->begin(), result->end());
return kTfLiteOk;
}
Interpreter::TfLiteDelegatePtr CreateNNAPIDelegate() {
#if defined(__ANDROID__)
return Interpreter::TfLiteDelegatePtr(

View File

@ -19,13 +19,19 @@ limitations under the License.
#include <string>
#include <vector>
#include "tensorflow/lite/context.h"
#include "tensorflow/lite/model.h"
namespace tflite {
namespace evaluation {
std::string StripTrailingSlashes(const std::string& path);
bool ReadFileLines(const std::string& file_path,
std::vector<std::string>* lines_output);
TfLiteStatus GetSortedFileNames(const std::string& directory,
std::vector<std::string>* result);
Interpreter::TfLiteDelegatePtr CreateNNAPIDelegate();
Interpreter::TfLiteDelegatePtr CreateGPUDelegate(FlatBufferModel* model);

View File

@ -18,16 +18,32 @@ limitations under the License.
#include <vector>
#include <gtest/gtest.h>
#include "tensorflow/lite/context.h"
namespace tflite {
namespace evaluation {
namespace {
constexpr char kFilePath[] =
constexpr char kLabelsPath[] =
"tensorflow/lite/tools/evaluation/testdata/labels.txt";
constexpr char kDirPath[] =
"tensorflow/lite/tools/evaluation/testdata";
constexpr char kEmptyFilePath[] =
"tensorflow/lite/tools/evaluation/testdata/empty.txt";
TEST(UtilsTest, StripTrailingSlashesTest) {
std::string path = "/usr/local/folder/";
EXPECT_EQ(StripTrailingSlashes(path), "/usr/local/folder");
path = "/usr/local/folder";
EXPECT_EQ(StripTrailingSlashes(path), path);
path = "folder";
EXPECT_EQ(StripTrailingSlashes(path), path);
}
TEST(UtilsTest, ReadFileErrors) {
std::string correct_path(kFilePath);
std::string correct_path(kLabelsPath);
std::string wrong_path("xyz.txt");
std::vector<std::string> lines;
EXPECT_FALSE(ReadFileLines(correct_path, nullptr));
@ -35,7 +51,7 @@ TEST(UtilsTest, ReadFileErrors) {
}
TEST(UtilsTest, ReadFileCorrectly) {
std::string file_path(kFilePath);
std::string file_path(kLabelsPath);
std::vector<std::string> lines;
EXPECT_TRUE(ReadFileLines(file_path, &lines));
@ -44,6 +60,17 @@ TEST(UtilsTest, ReadFileCorrectly) {
EXPECT_EQ(lines[1], "label2");
}
TEST(UtilsTest, SortedFilenamesTest) {
std::vector<std::string> files;
EXPECT_EQ(GetSortedFileNames(kDirPath, &files), kTfLiteOk);
EXPECT_EQ(files.size(), 2);
EXPECT_EQ(files[0], kEmptyFilePath);
EXPECT_EQ(files[1], kLabelsPath);
EXPECT_EQ(GetSortedFileNames("wrong_path", &files), kTfLiteError);
}
} // namespace
} // namespace evaluation
} // namespace tflite