Make RecentRequestIds more efficient.

PiperOrigin-RevId: 187242940
This commit is contained in:
Jeremy Lau 2018-02-27 15:32:16 -08:00 committed by TensorFlower Gardener
parent 2c25f08b6f
commit 53b2181ea5
4 changed files with 24 additions and 5 deletions

View File

@ -595,6 +595,7 @@ tf_cc_test(
srcs = ["recent_request_ids_test.cc"],
deps = [
":recent_request_ids",
":request_id",
"//tensorflow/core:lib",
"//tensorflow/core:test",
"//tensorflow/core:test_main",

View File

@ -15,6 +15,8 @@ limitations under the License.
#include "tensorflow/core/distributed_runtime/recent_request_ids.h"
#include <utility>
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
@ -29,12 +31,14 @@ RecentRequestIds::RecentRequestIds(int num_tracked_request_ids)
Status RecentRequestIds::TrackUnique(int64 request_id,
const string& method_name,
const protobuf::Message& request) {
mutex_lock l(mu_);
if (request_id == 0) {
// For backwards compatibility, allow all requests with request_id 0.
return Status::OK();
}
if (set_.count(request_id) > 0) {
mutex_lock l(mu_);
const bool inserted = set_.insert(request_id).second;
if (!inserted) {
// Note: RecentRequestIds is not strict LRU because we don't update
// request_id's age in the circular_buffer_ if it's tracked again. Strict
// LRU is not useful here because returning this error will close the
@ -49,7 +53,6 @@ Status RecentRequestIds::TrackUnique(int64 request_id,
// when the buffer is not yet full.
set_.erase(circular_buffer_[next_index_]);
circular_buffer_[next_index_] = request_id;
set_.insert(request_id);
next_index_ = (next_index_ + 1) % circular_buffer_.size();
return Status::OK();
}

View File

@ -16,11 +16,13 @@ limitations under the License.
#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RECENT_REQUEST_IDS_H_
#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RECENT_REQUEST_IDS_H_
#include <string>
#include <unordered_set>
#include <vector>
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/protobuf/worker.pb.h"
@ -64,7 +66,7 @@ class RecentRequestIds {
// request_id.
int next_index_ GUARDED_BY(mu_) = 0;
std::vector<int64> circular_buffer_ GUARDED_BY(mu_);
gtl::FlatSet<int64> set_ GUARDED_BY(mu_);
std::unordered_set<int64> set_ GUARDED_BY(mu_);
};
} // namespace tensorflow

View File

@ -17,8 +17,10 @@ limitations under the License.
#include <algorithm>
#include "tensorflow/core/distributed_runtime/request_id.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/protobuf/worker.pb.h"
@ -93,4 +95,15 @@ TEST(RecentRequestIds, Ordered3) { TestOrdered(3); }
TEST(RecentRequestIds, Ordered4) { TestOrdered(4); }
TEST(RecentRequestIds, Ordered5) { TestOrdered(5); }
void BM_TrackUnique(int iters) {
RecentRequestIds recent_request_ids(100000);
RecvTensorRequest request;
for (int i = 0; i < iters; ++i) {
TF_CHECK_OK(recent_request_ids.TrackUnique(GetUniqueRequestId(),
"BM_TrackUnique", request));
}
}
BENCHMARK(BM_TrackUnique);
} // namespace tensorflow