From efcb8bba816fd734091739aab7e999e5a13ba17f Mon Sep 17 00:00:00 2001
From: Sachin Joglekar <srjoglekar@google.com>
Date: Tue, 30 Apr 2019 09:31:24 -0700
Subject: [PATCH] Internal Change

PiperOrigin-RevId: 245962899
---
 .../ilsvrc/imagenet_model_evaluator.cc        | 40 +++----------------
 tensorflow/lite/tools/evaluation/BUILD        |  7 +++-
 tensorflow/lite/tools/evaluation/proto/BUILD  | 10 +++++
 .../evaluation/proto/evaluation_config.proto  |  4 ++
 .../evaluation/proto/evaluation_stages.proto  |  4 ++
 .../lite/tools/evaluation/testdata/empty.txt  |  0
 tensorflow/lite/tools/evaluation/utils.cc     | 35 ++++++++++++++++
 tensorflow/lite/tools/evaluation/utils.h      |  6 +++
 .../lite/tools/evaluation/utils_test.cc       | 33 +++++++++++++--
 9 files changed, 100 insertions(+), 39 deletions(-)
 create mode 100644 tensorflow/lite/tools/evaluation/testdata/empty.txt

diff --git a/tensorflow/lite/tools/accuracy/ilsvrc/imagenet_model_evaluator.cc b/tensorflow/lite/tools/accuracy/ilsvrc/imagenet_model_evaluator.cc
index 0d0865dc8fc..ecbd8a7234c 100644
--- a/tensorflow/lite/tools/accuracy/ilsvrc/imagenet_model_evaluator.cc
+++ b/tensorflow/lite/tools/accuracy/ilsvrc/imagenet_model_evaluator.cc
@@ -42,14 +42,6 @@ constexpr char kGroundTruthLabelsFlag[] = "ground_truth_labels";
 constexpr char kBlacklistFilePathFlag[] = "blacklist_file_path";
 constexpr char kModelFileFlag[] = "model_file";
 
-std::string StripTrailingSlashes(const std::string& path) {
-  int end = path.size();
-  while (end > 0 && path[end - 1] == '/') {
-    end--;
-  }
-  return path.substr(0, end);
-}
-
 template <typename T>
 std::vector<T> GetFirstN(const std::vector<T>& v, int n) {
   if (n >= v.size()) return v;
@@ -232,35 +224,13 @@ TfLiteStatus FilterBlackListedImages(const std::string& blacklist_file_path,
   return kTfLiteOk;
 }
 
-// TODO(b/130823599): Move to tools/evaluation/utils.
-TfLiteStatus GetSortedFileNames(const std::string dir_path,
-                                std::vector<std::string>* result) {
-  DIR* dir;
-  struct dirent* ent;
-  if (result == nullptr) {
-    LOG(ERROR) << "result cannot be nullptr";
-    return kTfLiteError;
-  }
-  if ((dir = opendir(dir_path.c_str())) != nullptr) {
-    while ((ent = readdir(dir)) != nullptr) {
-      std::string filename(std::string(ent->d_name));
-      if (filename.size() <= 2) continue;
-      result->emplace_back(dir_path + "/" + filename);
-    }
-    closedir(dir);
-  } else {
-    LOG(ERROR) << "Could not open dir: " << dir_path;
-    return kTfLiteError;
-  }
-  std::sort(result->begin(), result->end());
-  return kTfLiteOk;
-}
-
 TfLiteStatus ImagenetModelEvaluator::EvaluateModel() const {
-  const std::string data_path =
-      StripTrailingSlashes(params_.ground_truth_images_path) + "/";
+  const std::string data_path = tflite::evaluation::StripTrailingSlashes(
+                                    params_.ground_truth_images_path) +
+                                "/";
   std::vector<std::string> image_files;
-  TF_LITE_ENSURE_STATUS(GetSortedFileNames(data_path, &image_files));
+  TF_LITE_ENSURE_STATUS(
+      tflite::evaluation::GetSortedFileNames(data_path, &image_files));
   std::vector<string> ground_truth_image_labels;
   if (!tflite::evaluation::ReadFileLines(params_.ground_truth_labels_path,
                                          &ground_truth_image_labels))
diff --git a/tensorflow/lite/tools/evaluation/BUILD b/tensorflow/lite/tools/evaluation/BUILD
index c68f2e2d319..845ded49812 100644
--- a/tensorflow/lite/tools/evaluation/BUILD
+++ b/tensorflow/lite/tools/evaluation/BUILD
@@ -40,6 +40,7 @@ cc_library(
     copts = tflite_copts(),
     deps = [
         "//tensorflow/core:tflite_portable_logging",
+        "//tensorflow/lite:context",
         "//tensorflow/lite:framework",
         "//tensorflow/lite/delegates/nnapi:nnapi_delegate",
     ] + select({
@@ -53,11 +54,15 @@ cc_library(
 cc_test(
     name = "utils_test",
     srcs = ["utils_test.cc"],
-    data = ["testdata/labels.txt"],
+    data = [
+        "testdata/empty.txt",
+        "testdata/labels.txt",
+    ],
     linkopts = tflite_linkopts(),
     linkstatic = 1,
     deps = [
         ":utils",
+        "//tensorflow/lite:context",
         "@com_google_googletest//:gtest_main",
     ],
 )
diff --git a/tensorflow/lite/tools/evaluation/proto/BUILD b/tensorflow/lite/tools/evaluation/proto/BUILD
index d0fc459f345..fd1f0209f3e 100644
--- a/tensorflow/lite/tools/evaluation/proto/BUILD
+++ b/tensorflow/lite/tools/evaluation/proto/BUILD
@@ -30,6 +30,11 @@ cc_proto_library(
     deps = ["evaluation_stages_proto"],
 )
 
+java_proto_library(
+    name = "evaluation_stages_java_proto",
+    deps = ["evaluation_stages_proto"],
+)
+
 proto_library(
     name = "evaluation_config_proto",
     srcs = [
@@ -43,3 +48,8 @@ cc_proto_library(
     name = "evaluation_config_cc_proto",
     deps = ["evaluation_config_proto"],
 )
+
+java_proto_library(
+    name = "evaluation_config_java_proto",
+    deps = ["evaluation_config_proto"],
+)
diff --git a/tensorflow/lite/tools/evaluation/proto/evaluation_config.proto b/tensorflow/lite/tools/evaluation/proto/evaluation_config.proto
index b69ad6c306a..f95892c8bcc 100644
--- a/tensorflow/lite/tools/evaluation/proto/evaluation_config.proto
+++ b/tensorflow/lite/tools/evaluation/proto/evaluation_config.proto
@@ -19,6 +19,10 @@ package tflite.evaluation;
 
 import "tensorflow/lite/tools/evaluation/proto/evaluation_stages.proto";
 
+option cc_enable_arenas = true;
+option java_multiple_files = true;
+option java_package = "tflite.evaluation";
+
 // Contains parameters that define how an EvaluationStage will be executed.
 // This would typically be validated only once during initialization, so should
 // not contain any variables that change with each run.
diff --git a/tensorflow/lite/tools/evaluation/proto/evaluation_stages.proto b/tensorflow/lite/tools/evaluation/proto/evaluation_stages.proto
index a21fe5ca2f0..6c01787e6fd 100644
--- a/tensorflow/lite/tools/evaluation/proto/evaluation_stages.proto
+++ b/tensorflow/lite/tools/evaluation/proto/evaluation_stages.proto
@@ -17,6 +17,10 @@ syntax = "proto2";
 
 package tflite.evaluation;
 
+option cc_enable_arenas = true;
+option java_multiple_files = true;
+option java_package = "tflite.evaluation";
+
 // Defines the functionality executed by an EvaluationStage.
 //
 // Next ID: 5
diff --git a/tensorflow/lite/tools/evaluation/testdata/empty.txt b/tensorflow/lite/tools/evaluation/testdata/empty.txt
new file mode 100644
index 00000000000..e69de29bb2d
diff --git a/tensorflow/lite/tools/evaluation/utils.cc b/tensorflow/lite/tools/evaluation/utils.cc
index 55c33dcfa7a..7b6b3402463 100644
--- a/tensorflow/lite/tools/evaluation/utils.cc
+++ b/tensorflow/lite/tools/evaluation/utils.cc
@@ -15,8 +15,10 @@ limitations under the License.
 
 #include "tensorflow/lite/tools/evaluation/utils.h"
 
+#include <dirent.h>
 #include <sys/stat.h>
 
+#include <algorithm>
 #include <fstream>
 #include <memory>
 #include <string>
@@ -31,6 +33,14 @@ limitations under the License.
 namespace tflite {
 namespace evaluation {
 
+std::string StripTrailingSlashes(const std::string& path) {
+  int end = path.size();
+  while (end > 0 && path[end - 1] == '/') {
+    end--;
+  }
+  return path.substr(0, end);
+}
+
 bool ReadFileLines(const std::string& file_path,
                    std::vector<std::string>* lines_output) {
   if (!lines_output) {
@@ -49,6 +59,31 @@ bool ReadFileLines(const std::string& file_path,
   return true;
 }
 
+TfLiteStatus GetSortedFileNames(const std::string& directory,
+                                std::vector<std::string>* result) {
+  DIR* dir;
+  struct dirent* ent;
+  if (result == nullptr) {
+    LOG(ERROR) << "result cannot be nullptr";
+    return kTfLiteError;
+  }
+  result->clear();
+  std::string dir_path = StripTrailingSlashes(directory);
+  if ((dir = opendir(dir_path.c_str())) != nullptr) {
+    while ((ent = readdir(dir)) != nullptr) {
+      std::string filename(std::string(ent->d_name));
+      if (filename.size() <= 2) continue;
+      result->emplace_back(dir_path + "/" + filename);
+    }
+    closedir(dir);
+  } else {
+    LOG(ERROR) << "Could not open dir: " << dir_path;
+    return kTfLiteError;
+  }
+  std::sort(result->begin(), result->end());
+  return kTfLiteOk;
+}
+
 Interpreter::TfLiteDelegatePtr CreateNNAPIDelegate() {
 #if defined(__ANDROID__)
   return Interpreter::TfLiteDelegatePtr(
diff --git a/tensorflow/lite/tools/evaluation/utils.h b/tensorflow/lite/tools/evaluation/utils.h
index fef44dcdb94..1e2dbe0b765 100644
--- a/tensorflow/lite/tools/evaluation/utils.h
+++ b/tensorflow/lite/tools/evaluation/utils.h
@@ -19,13 +19,19 @@ limitations under the License.
 #include <string>
 #include <vector>
 
+#include "tensorflow/lite/context.h"
 #include "tensorflow/lite/model.h"
 
 namespace tflite {
 namespace evaluation {
+std::string StripTrailingSlashes(const std::string& path);
+
 bool ReadFileLines(const std::string& file_path,
                    std::vector<std::string>* lines_output);
 
+TfLiteStatus GetSortedFileNames(const std::string& directory,
+                                std::vector<std::string>* result);
+
 Interpreter::TfLiteDelegatePtr CreateNNAPIDelegate();
 
 Interpreter::TfLiteDelegatePtr CreateGPUDelegate(FlatBufferModel* model);
diff --git a/tensorflow/lite/tools/evaluation/utils_test.cc b/tensorflow/lite/tools/evaluation/utils_test.cc
index 6406db7e861..498de13197a 100644
--- a/tensorflow/lite/tools/evaluation/utils_test.cc
+++ b/tensorflow/lite/tools/evaluation/utils_test.cc
@@ -18,16 +18,32 @@ limitations under the License.
 #include <vector>
 
 #include <gtest/gtest.h>
+#include "tensorflow/lite/context.h"
 
 namespace tflite {
 namespace evaluation {
 namespace {
 
-constexpr char kFilePath[] =
+constexpr char kLabelsPath[] =
     "tensorflow/lite/tools/evaluation/testdata/labels.txt";
+constexpr char kDirPath[] =
+    "tensorflow/lite/tools/evaluation/testdata";
+constexpr char kEmptyFilePath[] =
+    "tensorflow/lite/tools/evaluation/testdata/empty.txt";
+
+TEST(UtilsTest, StripTrailingSlashesTest) {
+  std::string path = "/usr/local/folder/";
+  EXPECT_EQ(StripTrailingSlashes(path), "/usr/local/folder");
+
+  path = "/usr/local/folder";
+  EXPECT_EQ(StripTrailingSlashes(path), path);
+
+  path = "folder";
+  EXPECT_EQ(StripTrailingSlashes(path), path);
+}
 
 TEST(UtilsTest, ReadFileErrors) {
-  std::string correct_path(kFilePath);
+  std::string correct_path(kLabelsPath);
   std::string wrong_path("xyz.txt");
   std::vector<std::string> lines;
   EXPECT_FALSE(ReadFileLines(correct_path, nullptr));
@@ -35,7 +51,7 @@ TEST(UtilsTest, ReadFileErrors) {
 }
 
 TEST(UtilsTest, ReadFileCorrectly) {
-  std::string file_path(kFilePath);
+  std::string file_path(kLabelsPath);
   std::vector<std::string> lines;
   EXPECT_TRUE(ReadFileLines(file_path, &lines));
 
@@ -44,6 +60,17 @@ TEST(UtilsTest, ReadFileCorrectly) {
   EXPECT_EQ(lines[1], "label2");
 }
 
+TEST(UtilsTest, SortedFilenamesTest) {
+  std::vector<std::string> files;
+  EXPECT_EQ(GetSortedFileNames(kDirPath, &files), kTfLiteOk);
+
+  EXPECT_EQ(files.size(), 2);
+  EXPECT_EQ(files[0], kEmptyFilePath);
+  EXPECT_EQ(files[1], kLabelsPath);
+
+  EXPECT_EQ(GetSortedFileNames("wrong_path", &files), kTfLiteError);
+}
+
 }  // namespace
 }  // namespace evaluation
 }  // namespace tflite