Make RecentRequestIds more efficient.
PiperOrigin-RevId: 187242940
This commit is contained in:
parent
2c25f08b6f
commit
53b2181ea5
@ -595,6 +595,7 @@ tf_cc_test(
|
||||
srcs = ["recent_request_ids_test.cc"],
|
||||
deps = [
|
||||
":recent_request_ids",
|
||||
":request_id",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
|
@ -15,6 +15,8 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/core/distributed_runtime/recent_request_ids.h"
|
||||
|
||||
#include <utility>
|
||||
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/strings/strcat.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
@ -29,12 +31,14 @@ RecentRequestIds::RecentRequestIds(int num_tracked_request_ids)
|
||||
Status RecentRequestIds::TrackUnique(int64 request_id,
|
||||
const string& method_name,
|
||||
const protobuf::Message& request) {
|
||||
mutex_lock l(mu_);
|
||||
if (request_id == 0) {
|
||||
// For backwards compatibility, allow all requests with request_id 0.
|
||||
return Status::OK();
|
||||
}
|
||||
if (set_.count(request_id) > 0) {
|
||||
|
||||
mutex_lock l(mu_);
|
||||
const bool inserted = set_.insert(request_id).second;
|
||||
if (!inserted) {
|
||||
// Note: RecentRequestIds is not strict LRU because we don't update
|
||||
// request_id's age in the circular_buffer_ if it's tracked again. Strict
|
||||
// LRU is not useful here because returning this error will close the
|
||||
@ -49,7 +53,6 @@ Status RecentRequestIds::TrackUnique(int64 request_id,
|
||||
// when the buffer is not yet full.
|
||||
set_.erase(circular_buffer_[next_index_]);
|
||||
circular_buffer_[next_index_] = request_id;
|
||||
set_.insert(request_id);
|
||||
next_index_ = (next_index_ + 1) % circular_buffer_.size();
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -16,11 +16,13 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RECENT_REQUEST_IDS_H_
|
||||
#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RECENT_REQUEST_IDS_H_
|
||||
|
||||
#include <string>
|
||||
#include <unordered_set>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/lib/gtl/flatset.h"
|
||||
#include "tensorflow/core/platform/mutex.h"
|
||||
#include "tensorflow/core/platform/protobuf.h"
|
||||
#include "tensorflow/core/platform/thread_annotations.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
#include "tensorflow/core/protobuf/worker.pb.h"
|
||||
@ -64,7 +66,7 @@ class RecentRequestIds {
|
||||
// request_id.
|
||||
int next_index_ GUARDED_BY(mu_) = 0;
|
||||
std::vector<int64> circular_buffer_ GUARDED_BY(mu_);
|
||||
gtl::FlatSet<int64> set_ GUARDED_BY(mu_);
|
||||
std::unordered_set<int64> set_ GUARDED_BY(mu_);
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -17,8 +17,10 @@ limitations under the License.
|
||||
|
||||
#include <algorithm>
|
||||
|
||||
#include "tensorflow/core/distributed_runtime/request_id.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
#include "tensorflow/core/platform/test_benchmark.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
#include "tensorflow/core/protobuf/worker.pb.h"
|
||||
|
||||
@ -93,4 +95,15 @@ TEST(RecentRequestIds, Ordered3) { TestOrdered(3); }
|
||||
TEST(RecentRequestIds, Ordered4) { TestOrdered(4); }
|
||||
TEST(RecentRequestIds, Ordered5) { TestOrdered(5); }
|
||||
|
||||
void BM_TrackUnique(int iters) {
|
||||
RecentRequestIds recent_request_ids(100000);
|
||||
RecvTensorRequest request;
|
||||
for (int i = 0; i < iters; ++i) {
|
||||
TF_CHECK_OK(recent_request_ids.TrackUnique(GetUniqueRequestId(),
|
||||
"BM_TrackUnique", request));
|
||||
}
|
||||
}
|
||||
|
||||
BENCHMARK(BM_TrackUnique);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
Loading…
Reference in New Issue
Block a user