diff --git a/tensorflow/core/distributed_runtime/BUILD b/tensorflow/core/distributed_runtime/BUILD index 9e152aa0823..434626bd2da 100644 --- a/tensorflow/core/distributed_runtime/BUILD +++ b/tensorflow/core/distributed_runtime/BUILD @@ -595,6 +595,7 @@ tf_cc_test( srcs = ["recent_request_ids_test.cc"], deps = [ ":recent_request_ids", + ":request_id", "//tensorflow/core:lib", "//tensorflow/core:test", "//tensorflow/core:test_main", diff --git a/tensorflow/core/distributed_runtime/recent_request_ids.cc b/tensorflow/core/distributed_runtime/recent_request_ids.cc index c30879406c6..4f6866c5d15 100644 --- a/tensorflow/core/distributed_runtime/recent_request_ids.cc +++ b/tensorflow/core/distributed_runtime/recent_request_ids.cc @@ -15,6 +15,8 @@ limitations under the License. #include "tensorflow/core/distributed_runtime/recent_request_ids.h" +#include + #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" @@ -29,12 +31,14 @@ RecentRequestIds::RecentRequestIds(int num_tracked_request_ids) Status RecentRequestIds::TrackUnique(int64 request_id, const string& method_name, const protobuf::Message& request) { - mutex_lock l(mu_); if (request_id == 0) { // For backwards compatibility, allow all requests with request_id 0. return Status::OK(); } - if (set_.count(request_id) > 0) { + + mutex_lock l(mu_); + const bool inserted = set_.insert(request_id).second; + if (!inserted) { // Note: RecentRequestIds is not strict LRU because we don't update // request_id's age in the circular_buffer_ if it's tracked again. Strict // LRU is not useful here because returning this error will close the @@ -49,7 +53,6 @@ Status RecentRequestIds::TrackUnique(int64 request_id, // when the buffer is not yet full. set_.erase(circular_buffer_[next_index_]); circular_buffer_[next_index_] = request_id; - set_.insert(request_id); next_index_ = (next_index_ + 1) % circular_buffer_.size(); return Status::OK(); } diff --git a/tensorflow/core/distributed_runtime/recent_request_ids.h b/tensorflow/core/distributed_runtime/recent_request_ids.h index e8e45331dd5..11cf937c946 100644 --- a/tensorflow/core/distributed_runtime/recent_request_ids.h +++ b/tensorflow/core/distributed_runtime/recent_request_ids.h @@ -16,11 +16,13 @@ limitations under the License. #ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RECENT_REQUEST_IDS_H_ #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RECENT_REQUEST_IDS_H_ +#include +#include #include #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/thread_annotations.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/protobuf/worker.pb.h" @@ -64,7 +66,7 @@ class RecentRequestIds { // request_id. int next_index_ GUARDED_BY(mu_) = 0; std::vector circular_buffer_ GUARDED_BY(mu_); - gtl::FlatSet set_ GUARDED_BY(mu_); + std::unordered_set set_ GUARDED_BY(mu_); }; } // namespace tensorflow diff --git a/tensorflow/core/distributed_runtime/recent_request_ids_test.cc b/tensorflow/core/distributed_runtime/recent_request_ids_test.cc index 9a0facf5404..8910a50e9cd 100644 --- a/tensorflow/core/distributed_runtime/recent_request_ids_test.cc +++ b/tensorflow/core/distributed_runtime/recent_request_ids_test.cc @@ -17,8 +17,10 @@ limitations under the License. #include +#include "tensorflow/core/distributed_runtime/request_id.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/test_benchmark.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/protobuf/worker.pb.h" @@ -93,4 +95,15 @@ TEST(RecentRequestIds, Ordered3) { TestOrdered(3); } TEST(RecentRequestIds, Ordered4) { TestOrdered(4); } TEST(RecentRequestIds, Ordered5) { TestOrdered(5); } +void BM_TrackUnique(int iters) { + RecentRequestIds recent_request_ids(100000); + RecvTensorRequest request; + for (int i = 0; i < iters; ++i) { + TF_CHECK_OK(recent_request_ids.TrackUnique(GetUniqueRequestId(), + "BM_TrackUnique", request)); + } +} + +BENCHMARK(BM_TrackUnique); + } // namespace tensorflow