Internal Change

PiperOrigin-RevId: 245962899
This commit is contained in:
Sachin Joglekar 2019-04-30 09:31:24 -07:00 committed by TensorFlower Gardener
parent aa7d77dedc
commit efcb8bba81
9 changed files with 100 additions and 39 deletions

View File

@ -42,14 +42,6 @@ constexpr char kGroundTruthLabelsFlag[] = "ground_truth_labels";
constexpr char kBlacklistFilePathFlag[] = "blacklist_file_path"; constexpr char kBlacklistFilePathFlag[] = "blacklist_file_path";
constexpr char kModelFileFlag[] = "model_file"; constexpr char kModelFileFlag[] = "model_file";
std::string StripTrailingSlashes(const std::string& path) {
int end = path.size();
while (end > 0 && path[end - 1] == '/') {
end--;
}
return path.substr(0, end);
}
template <typename T> template <typename T>
std::vector<T> GetFirstN(const std::vector<T>& v, int n) { std::vector<T> GetFirstN(const std::vector<T>& v, int n) {
if (n >= v.size()) return v; if (n >= v.size()) return v;
@ -232,35 +224,13 @@ TfLiteStatus FilterBlackListedImages(const std::string& blacklist_file_path,
return kTfLiteOk; return kTfLiteOk;
} }
// TODO(b/130823599): Move to tools/evaluation/utils.
TfLiteStatus GetSortedFileNames(const std::string dir_path,
std::vector<std::string>* result) {
DIR* dir;
struct dirent* ent;
if (result == nullptr) {
LOG(ERROR) << "result cannot be nullptr";
return kTfLiteError;
}
if ((dir = opendir(dir_path.c_str())) != nullptr) {
while ((ent = readdir(dir)) != nullptr) {
std::string filename(std::string(ent->d_name));
if (filename.size() <= 2) continue;
result->emplace_back(dir_path + "/" + filename);
}
closedir(dir);
} else {
LOG(ERROR) << "Could not open dir: " << dir_path;
return kTfLiteError;
}
std::sort(result->begin(), result->end());
return kTfLiteOk;
}
TfLiteStatus ImagenetModelEvaluator::EvaluateModel() const { TfLiteStatus ImagenetModelEvaluator::EvaluateModel() const {
const std::string data_path = const std::string data_path = tflite::evaluation::StripTrailingSlashes(
StripTrailingSlashes(params_.ground_truth_images_path) + "/"; params_.ground_truth_images_path) +
"/";
std::vector<std::string> image_files; std::vector<std::string> image_files;
TF_LITE_ENSURE_STATUS(GetSortedFileNames(data_path, &image_files)); TF_LITE_ENSURE_STATUS(
tflite::evaluation::GetSortedFileNames(data_path, &image_files));
std::vector<string> ground_truth_image_labels; std::vector<string> ground_truth_image_labels;
if (!tflite::evaluation::ReadFileLines(params_.ground_truth_labels_path, if (!tflite::evaluation::ReadFileLines(params_.ground_truth_labels_path,
&ground_truth_image_labels)) &ground_truth_image_labels))

View File

@ -40,6 +40,7 @@ cc_library(
copts = tflite_copts(), copts = tflite_copts(),
deps = [ deps = [
"//tensorflow/core:tflite_portable_logging", "//tensorflow/core:tflite_portable_logging",
"//tensorflow/lite:context",
"//tensorflow/lite:framework", "//tensorflow/lite:framework",
"//tensorflow/lite/delegates/nnapi:nnapi_delegate", "//tensorflow/lite/delegates/nnapi:nnapi_delegate",
] + select({ ] + select({
@ -53,11 +54,15 @@ cc_library(
cc_test( cc_test(
name = "utils_test", name = "utils_test",
srcs = ["utils_test.cc"], srcs = ["utils_test.cc"],
data = ["testdata/labels.txt"], data = [
"testdata/empty.txt",
"testdata/labels.txt",
],
linkopts = tflite_linkopts(), linkopts = tflite_linkopts(),
linkstatic = 1, linkstatic = 1,
deps = [ deps = [
":utils", ":utils",
"//tensorflow/lite:context",
"@com_google_googletest//:gtest_main", "@com_google_googletest//:gtest_main",
], ],
) )

View File

@ -30,6 +30,11 @@ cc_proto_library(
deps = ["evaluation_stages_proto"], deps = ["evaluation_stages_proto"],
) )
java_proto_library(
name = "evaluation_stages_java_proto",
deps = ["evaluation_stages_proto"],
)
proto_library( proto_library(
name = "evaluation_config_proto", name = "evaluation_config_proto",
srcs = [ srcs = [
@ -43,3 +48,8 @@ cc_proto_library(
name = "evaluation_config_cc_proto", name = "evaluation_config_cc_proto",
deps = ["evaluation_config_proto"], deps = ["evaluation_config_proto"],
) )
java_proto_library(
name = "evaluation_config_java_proto",
deps = ["evaluation_config_proto"],
)

View File

@ -19,6 +19,10 @@ package tflite.evaluation;
import "tensorflow/lite/tools/evaluation/proto/evaluation_stages.proto"; import "tensorflow/lite/tools/evaluation/proto/evaluation_stages.proto";
option cc_enable_arenas = true;
option java_multiple_files = true;
option java_package = "tflite.evaluation";
// Contains parameters that define how an EvaluationStage will be executed. // Contains parameters that define how an EvaluationStage will be executed.
// This would typically be validated only once during initialization, so should // This would typically be validated only once during initialization, so should
// not contain any variables that change with each run. // not contain any variables that change with each run.

View File

@ -17,6 +17,10 @@ syntax = "proto2";
package tflite.evaluation; package tflite.evaluation;
option cc_enable_arenas = true;
option java_multiple_files = true;
option java_package = "tflite.evaluation";
// Defines the functionality executed by an EvaluationStage. // Defines the functionality executed by an EvaluationStage.
// //
// Next ID: 5 // Next ID: 5

View File

View File

@ -15,8 +15,10 @@ limitations under the License.
#include "tensorflow/lite/tools/evaluation/utils.h" #include "tensorflow/lite/tools/evaluation/utils.h"
#include <dirent.h>
#include <sys/stat.h> #include <sys/stat.h>
#include <algorithm>
#include <fstream> #include <fstream>
#include <memory> #include <memory>
#include <string> #include <string>
@ -31,6 +33,14 @@ limitations under the License.
namespace tflite { namespace tflite {
namespace evaluation { namespace evaluation {
std::string StripTrailingSlashes(const std::string& path) {
int end = path.size();
while (end > 0 && path[end - 1] == '/') {
end--;
}
return path.substr(0, end);
}
bool ReadFileLines(const std::string& file_path, bool ReadFileLines(const std::string& file_path,
std::vector<std::string>* lines_output) { std::vector<std::string>* lines_output) {
if (!lines_output) { if (!lines_output) {
@ -49,6 +59,31 @@ bool ReadFileLines(const std::string& file_path,
return true; return true;
} }
TfLiteStatus GetSortedFileNames(const std::string& directory,
std::vector<std::string>* result) {
DIR* dir;
struct dirent* ent;
if (result == nullptr) {
LOG(ERROR) << "result cannot be nullptr";
return kTfLiteError;
}
result->clear();
std::string dir_path = StripTrailingSlashes(directory);
if ((dir = opendir(dir_path.c_str())) != nullptr) {
while ((ent = readdir(dir)) != nullptr) {
std::string filename(std::string(ent->d_name));
if (filename.size() <= 2) continue;
result->emplace_back(dir_path + "/" + filename);
}
closedir(dir);
} else {
LOG(ERROR) << "Could not open dir: " << dir_path;
return kTfLiteError;
}
std::sort(result->begin(), result->end());
return kTfLiteOk;
}
Interpreter::TfLiteDelegatePtr CreateNNAPIDelegate() { Interpreter::TfLiteDelegatePtr CreateNNAPIDelegate() {
#if defined(__ANDROID__) #if defined(__ANDROID__)
return Interpreter::TfLiteDelegatePtr( return Interpreter::TfLiteDelegatePtr(

View File

@ -19,13 +19,19 @@ limitations under the License.
#include <string> #include <string>
#include <vector> #include <vector>
#include "tensorflow/lite/context.h"
#include "tensorflow/lite/model.h" #include "tensorflow/lite/model.h"
namespace tflite { namespace tflite {
namespace evaluation { namespace evaluation {
std::string StripTrailingSlashes(const std::string& path);
bool ReadFileLines(const std::string& file_path, bool ReadFileLines(const std::string& file_path,
std::vector<std::string>* lines_output); std::vector<std::string>* lines_output);
TfLiteStatus GetSortedFileNames(const std::string& directory,
std::vector<std::string>* result);
Interpreter::TfLiteDelegatePtr CreateNNAPIDelegate(); Interpreter::TfLiteDelegatePtr CreateNNAPIDelegate();
Interpreter::TfLiteDelegatePtr CreateGPUDelegate(FlatBufferModel* model); Interpreter::TfLiteDelegatePtr CreateGPUDelegate(FlatBufferModel* model);

View File

@ -18,16 +18,32 @@ limitations under the License.
#include <vector> #include <vector>
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include "tensorflow/lite/context.h"
namespace tflite { namespace tflite {
namespace evaluation { namespace evaluation {
namespace { namespace {
constexpr char kFilePath[] = constexpr char kLabelsPath[] =
"tensorflow/lite/tools/evaluation/testdata/labels.txt"; "tensorflow/lite/tools/evaluation/testdata/labels.txt";
constexpr char kDirPath[] =
"tensorflow/lite/tools/evaluation/testdata";
constexpr char kEmptyFilePath[] =
"tensorflow/lite/tools/evaluation/testdata/empty.txt";
TEST(UtilsTest, StripTrailingSlashesTest) {
std::string path = "/usr/local/folder/";
EXPECT_EQ(StripTrailingSlashes(path), "/usr/local/folder");
path = "/usr/local/folder";
EXPECT_EQ(StripTrailingSlashes(path), path);
path = "folder";
EXPECT_EQ(StripTrailingSlashes(path), path);
}
TEST(UtilsTest, ReadFileErrors) { TEST(UtilsTest, ReadFileErrors) {
std::string correct_path(kFilePath); std::string correct_path(kLabelsPath);
std::string wrong_path("xyz.txt"); std::string wrong_path("xyz.txt");
std::vector<std::string> lines; std::vector<std::string> lines;
EXPECT_FALSE(ReadFileLines(correct_path, nullptr)); EXPECT_FALSE(ReadFileLines(correct_path, nullptr));
@ -35,7 +51,7 @@ TEST(UtilsTest, ReadFileErrors) {
} }
TEST(UtilsTest, ReadFileCorrectly) { TEST(UtilsTest, ReadFileCorrectly) {
std::string file_path(kFilePath); std::string file_path(kLabelsPath);
std::vector<std::string> lines; std::vector<std::string> lines;
EXPECT_TRUE(ReadFileLines(file_path, &lines)); EXPECT_TRUE(ReadFileLines(file_path, &lines));
@ -44,6 +60,17 @@ TEST(UtilsTest, ReadFileCorrectly) {
EXPECT_EQ(lines[1], "label2"); EXPECT_EQ(lines[1], "label2");
} }
TEST(UtilsTest, SortedFilenamesTest) {
std::vector<std::string> files;
EXPECT_EQ(GetSortedFileNames(kDirPath, &files), kTfLiteOk);
EXPECT_EQ(files.size(), 2);
EXPECT_EQ(files[0], kEmptyFilePath);
EXPECT_EQ(files[1], kLabelsPath);
EXPECT_EQ(GetSortedFileNames("wrong_path", &files), kTfLiteError);
}
} // namespace } // namespace
} // namespace evaluation } // namespace evaluation
} // namespace tflite } // namespace tflite