From 53b2181ea5cff054d40c583f05da942a9a56a283 Mon Sep 17 00:00:00 2001
From: Jeremy Lau <lauj@google.com>
Date: Tue, 27 Feb 2018 15:32:16 -0800
Subject: [PATCH] Make RecentRequestIds more efficient.

PiperOrigin-RevId: 187242940
---
 tensorflow/core/distributed_runtime/BUILD           |  1 +
 .../core/distributed_runtime/recent_request_ids.cc  |  9 ++++++---
 .../core/distributed_runtime/recent_request_ids.h   |  6 ++++--
 .../distributed_runtime/recent_request_ids_test.cc  | 13 +++++++++++++
 4 files changed, 24 insertions(+), 5 deletions(-)

diff --git a/tensorflow/core/distributed_runtime/BUILD b/tensorflow/core/distributed_runtime/BUILD
index 9e152aa0823..434626bd2da 100644
--- a/tensorflow/core/distributed_runtime/BUILD
+++ b/tensorflow/core/distributed_runtime/BUILD
@@ -595,6 +595,7 @@ tf_cc_test(
     srcs = ["recent_request_ids_test.cc"],
     deps = [
         ":recent_request_ids",
+        ":request_id",
         "//tensorflow/core:lib",
         "//tensorflow/core:test",
         "//tensorflow/core:test_main",
diff --git a/tensorflow/core/distributed_runtime/recent_request_ids.cc b/tensorflow/core/distributed_runtime/recent_request_ids.cc
index c30879406c6..4f6866c5d15 100644
--- a/tensorflow/core/distributed_runtime/recent_request_ids.cc
+++ b/tensorflow/core/distributed_runtime/recent_request_ids.cc
@@ -15,6 +15,8 @@ limitations under the License.
 
 #include "tensorflow/core/distributed_runtime/recent_request_ids.h"
 
+#include <utility>
+
 #include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/lib/strings/strcat.h"
 #include "tensorflow/core/platform/logging.h"
@@ -29,12 +31,14 @@ RecentRequestIds::RecentRequestIds(int num_tracked_request_ids)
 Status RecentRequestIds::TrackUnique(int64 request_id,
                                      const string& method_name,
                                      const protobuf::Message& request) {
-  mutex_lock l(mu_);
   if (request_id == 0) {
     // For backwards compatibility, allow all requests with request_id 0.
     return Status::OK();
   }
-  if (set_.count(request_id) > 0) {
+
+  mutex_lock l(mu_);
+  const bool inserted = set_.insert(request_id).second;
+  if (!inserted) {
     // Note: RecentRequestIds is not strict LRU because we don't update
     // request_id's age in the circular_buffer_ if it's tracked again. Strict
     // LRU is not useful here because returning this error will close the
@@ -49,7 +53,6 @@ Status RecentRequestIds::TrackUnique(int64 request_id,
   // when the buffer is not yet full.
   set_.erase(circular_buffer_[next_index_]);
   circular_buffer_[next_index_] = request_id;
-  set_.insert(request_id);
   next_index_ = (next_index_ + 1) % circular_buffer_.size();
   return Status::OK();
 }
diff --git a/tensorflow/core/distributed_runtime/recent_request_ids.h b/tensorflow/core/distributed_runtime/recent_request_ids.h
index e8e45331dd5..11cf937c946 100644
--- a/tensorflow/core/distributed_runtime/recent_request_ids.h
+++ b/tensorflow/core/distributed_runtime/recent_request_ids.h
@@ -16,11 +16,13 @@ limitations under the License.
 #ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RECENT_REQUEST_IDS_H_
 #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RECENT_REQUEST_IDS_H_
 
+#include <string>
+#include <unordered_set>
 #include <vector>
 
 #include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/lib/gtl/flatset.h"
 #include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/core/platform/protobuf.h"
 #include "tensorflow/core/platform/thread_annotations.h"
 #include "tensorflow/core/platform/types.h"
 #include "tensorflow/core/protobuf/worker.pb.h"
@@ -64,7 +66,7 @@ class RecentRequestIds {
   // request_id.
   int next_index_ GUARDED_BY(mu_) = 0;
   std::vector<int64> circular_buffer_ GUARDED_BY(mu_);
-  gtl::FlatSet<int64> set_ GUARDED_BY(mu_);
+  std::unordered_set<int64> set_ GUARDED_BY(mu_);
 };
 
 }  // namespace tensorflow
diff --git a/tensorflow/core/distributed_runtime/recent_request_ids_test.cc b/tensorflow/core/distributed_runtime/recent_request_ids_test.cc
index 9a0facf5404..8910a50e9cd 100644
--- a/tensorflow/core/distributed_runtime/recent_request_ids_test.cc
+++ b/tensorflow/core/distributed_runtime/recent_request_ids_test.cc
@@ -17,8 +17,10 @@ limitations under the License.
 
 #include <algorithm>
 
+#include "tensorflow/core/distributed_runtime/request_id.h"
 #include "tensorflow/core/lib/core/status_test_util.h"
 #include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/test_benchmark.h"
 #include "tensorflow/core/platform/types.h"
 #include "tensorflow/core/protobuf/worker.pb.h"
 
@@ -93,4 +95,15 @@ TEST(RecentRequestIds, Ordered3) { TestOrdered(3); }
 TEST(RecentRequestIds, Ordered4) { TestOrdered(4); }
 TEST(RecentRequestIds, Ordered5) { TestOrdered(5); }
 
+void BM_TrackUnique(int iters) {
+  RecentRequestIds recent_request_ids(100000);
+  RecvTensorRequest request;
+  for (int i = 0; i < iters; ++i) {
+    TF_CHECK_OK(recent_request_ids.TrackUnique(GetUniqueRequestId(),
+                                               "BM_TrackUnique", request));
+  }
+}
+
+BENCHMARK(BM_TrackUnique);
+
 }  // namespace tensorflow