[tf.data service] Fix replay issue in round-robin reads.

Previously we indexed requests by (1) round index and (2) pointer addresses. This is a problem when we get multiple requests for the same round, e.g. due to retries. This could cause the same consumer to be counted multiple times, potentially producing a round of data before all consumers are ready.

PiperOrigin-RevId: 356858844
Change-Id: I2511fabb769bb5211b4d7d9c4185df4f3ce0e4cc
This commit is contained in:
Andrew Audibert 2021-02-10 16:48:35 -08:00 committed by TensorFlower Gardener
parent 4f06d670ff
commit ba45035f3a
2 changed files with 11 additions and 7 deletions
tensorflow/core/data/service

View File

@ -161,8 +161,8 @@ Status RoundRobinTaskRunner::PreparePartialRound()
current_round_ = first_round_;
new_round_cv_.notify_all();
// Indicates that we need a partial round to get consumers back in sync.
auto next_round_request = *(requests_[first_round_ + 1].begin());
if (next_round_request->skipped_previous_round()) {
auto next_round_request = *(requests_[first_round_ + 1].begin()->second);
if (next_round_request.skipped_previous_round()) {
VLOG(1) << "Skipping partial round";
round_skipped_ = true;
return Status::OK();
@ -174,10 +174,13 @@ Status RoundRobinTaskRunner::PreparePartialRound()
Status RoundRobinTaskRunner::PrepareRound(const GetElementRequest& req) {
mutex_lock l(mu_);
absl::flat_hash_set<const GetElementRequest*>& round =
requests_[req.round_index()];
first_round_ = std::min(first_round_, req.round_index());
round.insert(&req);
absl::flat_hash_map<int64, const GetElementRequest*>& round =
requests_[req.round_index()];
round[req.consumer_index()] = &req;
auto cleanup = gtl::MakeCleanup([&]() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
requests_[req.round_index()].erase(req.consumer_index());
});
if (current_round_ < req.round_index() && round.size() == num_consumers_) {
current_round_ = req.round_index();
int64 wait_us = kWaitBeforeSkipUs;

View File

@ -150,8 +150,9 @@ class RoundRobinTaskRunner : public TaskRunner {
mutex mu_;
// Condition variable notified whenever we start a new round of round-robin.
condition_variable new_round_cv_;
// Map from round number to requests waiting for data from that round.
absl::flat_hash_map<int64, absl::flat_hash_set<const GetElementRequest*>>
// Outstanding requests, indexed by round number and then consumer index.
absl::flat_hash_map<int64,
absl::flat_hash_map<int64, const GetElementRequest*>>
requests_ TF_GUARDED_BY(mu_);
// Index of the first round we plan to serve. At startup, this is the minimum
// of all requested element indices.