Rework the LocalRendezvousImpl item queue.

tensorflow/core/framework/rendezvous.cc
  * LocalRendezvousImpl is placed in an anonymous namespace.
  * VLOG(2) calls in Send()/RecvAsync() become DVLOG(2).
  * Item keeps one `args` member plus a `type` tag (kSend/kRecv) and a union
    of send_state / recv_state whose non-trivial members are wrapped in
    ManualConstructor; the destructor destroys only the active member.
    The device_context Ref()/Unref() pair moves into Item's constructor and
    destructor.
  * ItemQueue changes from std::deque<Item*> to a head/tail pair linked
    through Item::next. Cancellation unlinks the matching item by walking
    the list with a trailing `prev` pointer.

tensorflow/core/framework/rendezvous_test.cc
  * KeyFoo()/KeyBar() return a heap-allocated key that is never deleted
    instead of a function-local static object.
  * Adds the BM_RecvSend benchmark (RecvAsync first, then Send).

NOTE(review): line breaks, indentation and blank context lines were
reconstructed so that every hunk matches its "@@" line counts. Whitespace
inside context lines may still differ from the target tree; if
`git apply --check` rejects a hunk, retry with `--ignore-whitespace`.

diff --git a/tensorflow/core/framework/rendezvous.cc b/tensorflow/core/framework/rendezvous.cc
index 5219fb29192..ad3cf912d23 100644
--- a/tensorflow/core/framework/rendezvous.cc
+++ b/tensorflow/core/framework/rendezvous.cc
@@ -23,6 +23,7 @@ limitations under the License.
 #include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/lib/core/notification.h"
 #include "tensorflow/core/lib/gtl/flatmap.h"
+#include "tensorflow/core/lib/gtl/manual_constructor.h"
 #include "tensorflow/core/lib/hash/hash.h"
 #include "tensorflow/core/lib/strings/str_util.h"
 #include "tensorflow/core/platform/logging.h"
@@ -146,6 +147,7 @@ Status Rendezvous::Recv(const ParsedKey& key, const Args& args, Tensor* val,
   return Recv(key, args, val, is_dead, no_timeout);
 }
 
+namespace {
 class LocalRendezvousImpl : public Rendezvous {
  public:
   explicit LocalRendezvousImpl() {}
@@ -153,7 +155,7 @@ class LocalRendezvousImpl : public Rendezvous {
   Status Send(const ParsedKey& key, const Args& send_args, const Tensor& val,
               const bool is_dead) override {
     uint64 key_hash = KeyHash(key.FullKey());
-    VLOG(2) << "Send " << this << " " << key_hash << " " << key.FullKey();
+    DVLOG(2) << "Send " << this << " " << key_hash << " " << key.FullKey();
 
     mu_.lock();
     if (!status_.ok()) {
@@ -164,40 +166,36 @@ class LocalRendezvousImpl : public Rendezvous {
     }
 
     ItemQueue* queue = &table_[key_hash];
-    if (queue->empty() || queue->front()->IsSendValue()) {
+    if (queue->head == nullptr || queue->head->type == Item::kSend) {
       // There is no waiter for this message. Append the message
       // into the queue. The waiter will pick it up when arrives.
       // Only send-related fields need to be filled.
-      VLOG(2) << "Enqueue Send Item (key:" << key.FullKey() << "). ";
-      Item* item = new Item;
-      item->value = val;
-      item->is_dead = is_dead;
-      item->send_args = send_args;
-      if (item->send_args.device_context) {
-        item->send_args.device_context->Ref();
-      }
-      queue->push_back(item);
+      // TODO(b/143786186): Investigate moving the allocation of `Item` outside
+      // the lock.
+      DVLOG(2) << "Enqueue Send Item (key:" << key.FullKey() << "). ";
+      queue->push_back(new Item(send_args, val, is_dead));
       mu_.unlock();
       return Status::OK();
     }
 
-    VLOG(2) << "Consume Recv Item (key:" << key.FullKey() << "). ";
+    DVLOG(2) << "Consume Recv Item (key:" << key.FullKey() << "). ";
     // There is an earliest waiter to consume this message.
-    Item* item = queue->front();
+    Item* item = queue->head;
 
     // Delete the queue when the last element has been consumed.
-    if (queue->size() == 1) {
-      VLOG(2) << "Clean up Send/Recv queue (key:" << key.FullKey() << "). ";
+    if (item->next == nullptr) {
+      DVLOG(2) << "Clean up Send/Recv queue (key:" << key.FullKey() << "). ";
       table_.erase(key_hash);
     } else {
-      queue->pop_front();
+      queue->head = item->next;
     }
     mu_.unlock();
 
     // Notify the waiter by invoking its done closure, outside the
     // lock.
-    DCHECK(!item->IsSendValue());
-    item->waiter(Status::OK(), send_args, item->recv_args, val, is_dead);
+    DCHECK_EQ(item->type, Item::kRecv);
+    (*item->recv_state.waiter)(Status::OK(), send_args, item->args, val,
+                               is_dead);
     delete item;
     return Status::OK();
   }
@@ -205,7 +203,7 @@ class LocalRendezvousImpl : public Rendezvous {
   void RecvAsync(const ParsedKey& key, const Args& recv_args,
                  DoneCallback done) override {
     uint64 key_hash = KeyHash(key.FullKey());
-    VLOG(2) << "Recv " << this << " " << key_hash << " " << key.FullKey();
+    DVLOG(2) << "Recv " << this << " " << key_hash << " " << key.FullKey();
 
     mu_.lock();
     if (!status_.ok()) {
@@ -217,7 +215,7 @@ class LocalRendezvousImpl : public Rendezvous {
     }
 
     ItemQueue* queue = &table_[key_hash];
-    if (queue->empty() || !queue->front()->IsSendValue()) {
+    if (queue->head == nullptr || queue->head->type == Item::kRecv) {
       // There is no message to pick up.
       // Only recv-related fields need to be filled.
       CancellationManager* cm = recv_args.cancellation_manager;
@@ -231,14 +229,29 @@ class LocalRendezvousImpl : public Rendezvous {
           {
             mutex_lock l(mu_);
             ItemQueue* queue = &table_[key_hash];
-            if (!queue->empty() && !queue->front()->IsSendValue()) {
-              for (auto it = queue->begin(); it != queue->end(); it++) {
-                if ((*it)->cancellation_token == token) {
-                  item = *it;
-                  if (queue->size() == 1) {
+            // Find an item in the queue with a cancellation token that matches
+            // `token`, and remove it.
+            if (queue->head != nullptr && queue->head->type == Item::kRecv) {
+              for (Item *prev = nullptr, *curr = queue->head; curr != nullptr;
+                   prev = curr, curr = curr->next) {
+                if (curr->recv_state.cancellation_token == token) {
+                  item = curr;
+                  if (queue->head->next == nullptr) {
+                    // We have a single-element queue, so we can erase it from
+                    // the table.
                     table_.erase(key_hash);
                   } else {
-                    queue->erase(it);
+                    // Remove the current item from the queue.
+                    if (curr == queue->head) {
+                      DCHECK_EQ(prev, nullptr);
+                      queue->head = curr->next;
+                    } else {
+                      DCHECK_NE(prev, nullptr);
+                      prev->next = curr->next;
+                    }
+                    if (queue->tail == curr) {
+                      queue->tail = prev;
+                    }
                   }
                   break;
                 }
@@ -247,9 +260,10 @@ class LocalRendezvousImpl : public Rendezvous {
           }
 
           if (item != nullptr) {
-            item->waiter(StatusGroup::MakeDerived(
-                             errors::Cancelled("RecvAsync is cancelled.")),
-                         Args(), item->recv_args, Tensor(), /*is_dead=*/false);
+            (*item->recv_state.waiter)(
+                StatusGroup::MakeDerived(
+                    errors::Cancelled("RecvAsync is cancelled.")),
+                Args(), item->args, Tensor(), /*is_dead=*/false);
             delete item;
           }
         });
@@ -262,51 +276,49 @@ class LocalRendezvousImpl : public Rendezvous {
         return;
       }
 
-      VLOG(2) << "Enqueue Recv Item (key:" << key.FullKey() << "). ";
-      Item* item = new Item;
+      DVLOG(2) << "Enqueue Recv Item (key:" << key.FullKey() << "). ";
 
+      // TODO(b/143786186): Investigate moving the allocation of `Item` outside
+      // the lock.
       if (cm != nullptr) {
         // NOTE(mrry): We must wrap `done` with code that deregisters the
         // cancellation callback before calling the `done` callback, because the
         // cancellation manager may no longer be live after `done` is called.
-        item->waiter = [cm, token, done = std::move(done)](
-                           const Status& s, const Args& send_args,
-                           const Args& recv_args, const Tensor& v, bool dead) {
-          cm->TryDeregisterCallback(token);
-          done(s, send_args, recv_args, v, dead);
-        };
+        queue->push_back(new Item(
+            recv_args,
+            [cm, token, done = std::move(done)](
+                const Status& s, const Args& send_args, const Args& recv_args,
+                const Tensor& v, bool dead) {
+              cm->TryDeregisterCallback(token);
+              done(s, send_args, recv_args, v, dead);
+            },
+            token));
       } else {
-        item->waiter = std::move(done);
+        queue->push_back(new Item(recv_args, std::move(done), token));
       }
-      item->recv_args = recv_args;
-      item->cancellation_token = token;
-      if (item->recv_args.device_context) {
-        item->recv_args.device_context->Ref();
-      }
-      queue->push_back(item);
 
       mu_.unlock();
       return;
     }
 
-    VLOG(2) << "Consume Send Item (key:" << key.FullKey() << "). ";
+    DVLOG(2) << "Consume Send Item (key:" << key.FullKey() << "). ";
     // A message has already arrived and is queued in the table under
     // this key. Consumes the message and invokes the done closure.
-    Item* item = queue->front();
+    Item* item = queue->head;
 
     // Delete the queue when the last element has been consumed.
-    if (queue->size() == 1) {
-      VLOG(2) << "Clean up Send/Recv queue (key:" << key.FullKey() << "). ";
+    if (item->next == nullptr) {
+      DVLOG(2) << "Clean up Send/Recv queue (key:" << key.FullKey() << "). ";
       table_.erase(key_hash);
     } else {
-      queue->pop_front();
+      queue->head = item->next;
     }
     mu_.unlock();
 
-    // Invokes the done() by invoking its done closure, outside scope
-    // of the table lock.
-    DCHECK(item->IsSendValue());
-    done(Status::OK(), item->send_args, recv_args, item->value, item->is_dead);
+    // Invoke done() without holding the table lock.
+    DCHECK_EQ(item->type, Item::kSend);
+    done(Status::OK(), item->args, recv_args, *item->send_state.value,
+         item->send_state.is_dead);
     delete item;
   }
 
@@ -319,11 +331,14 @@ class LocalRendezvousImpl : public Rendezvous {
       table_.swap(table);
     }
     for (auto& p : table) {
-      for (Item* item : p.second) {
-        if (!item->IsSendValue()) {
-          item->waiter(status, Args(), Args(), Tensor(), false);
+      Item* item = p.second.head;
+      while (item != nullptr) {
+        if (item->type == Item::kRecv) {
+          (*item->recv_state.waiter)(status, Args(), Args(), Tensor(), false);
         }
-        delete item;
+        Item* to_delete = item;
+        item = item->next;
+        delete to_delete;
       }
     }
   }
@@ -331,25 +346,59 @@ class LocalRendezvousImpl : public Rendezvous {
  private:
   typedef LocalRendezvousImpl ME;
 
+  // Represents a blocked Send() or Recv() call in the rendezvous.
   struct Item {
-    DoneCallback waiter = nullptr;
-    Tensor value;
-    bool is_dead = false;
-    Args send_args;
-    Args recv_args;
-    CancellationToken cancellation_token;
+    enum Type { kSend = 0, kRecv = 1 };
+
+    Item(Args send_args, const Tensor& value, bool is_dead)
+        : Item(send_args, kSend) {
+      send_state.value.Init(value);
+      send_state.is_dead = is_dead;
+    }
+
+    Item(Args recv_args, DoneCallback waiter,
+         CancellationToken cancellation_token)
+        : Item(recv_args, kRecv) {
+      recv_state.waiter.Init(std::move(waiter));
+      recv_state.cancellation_token = cancellation_token;
+    }
 
     ~Item() {
-      if (send_args.device_context) {
-        send_args.device_context->Unref();
+      if (args.device_context) {
+        args.device_context->Unref();
       }
-      if (recv_args.device_context) {
-        recv_args.device_context->Unref();
+      if (type == kSend) {
+        send_state.value.Destroy();
+      } else {
+        recv_state.waiter.Destroy();
       }
     }
 
-    // Returns true iff this item represents a value being sent.
-    bool IsSendValue() const { return this->waiter == nullptr; }
+    const Args args;
+    const Type type;
+
+    // Link to next item in an ItemQueue.
+    Item* next = nullptr;
+
+    // The validity of `send_state` or `recv_state` is determined by `type ==
+    // kSend` or `type == kRecv` respectively.
+    union {
+      struct {
+        ManualConstructor<Tensor> value;
+        bool is_dead;
+      } send_state;
+      struct {
+        ManualConstructor<DoneCallback> waiter;
+        CancellationToken cancellation_token;
+      } recv_state;
+    };
+
+   private:
+    Item(Args args, Type type) : args(args), type(type) {
+      if (args.device_context) {
+        args.device_context->Ref();
+      }
+    }
   };
 
   // We key the hash table by KeyHash of the Rendezvous::CreateKey string
@@ -358,12 +407,25 @@ class LocalRendezvousImpl : public Rendezvous {
   }
 
   // By invariant, the item queue under each key is of the form
-  //   [item.IsSendValue()]* meaning each item is a sent message.
+  //   [item.type == kSend]* meaning each item is a sent message.
   // or
-  //   [!item.IsSendValue()]* meaning each item is a waiter.
-  //
-  // TODO(zhifengc): consider a better queue impl than std::deque.
-  typedef std::deque<Item*> ItemQueue;
+  //   [item.type == kRecv]* meaning each item is a waiter.
+  struct ItemQueue {
+    void push_back(Item* item) {
+      if (TF_PREDICT_TRUE(head == nullptr)) {
+        // The queue is empty.
+        head = item;
+        tail = item;
+      } else {
+        DCHECK_EQ(tail->type, item->type);
+        tail->next = item;
+        tail = item;
+      }
+    }
+
+    Item* head = nullptr;
+    Item* tail = nullptr;
+  };
   typedef gtl::FlatMap<uint64, ItemQueue> Table;
 
   // TODO(zhifengc): shard table_.
@@ -379,6 +441,7 @@ class LocalRendezvousImpl : public Rendezvous {
 
   TF_DISALLOW_COPY_AND_ASSIGN(LocalRendezvousImpl);
 };
+}  // namespace
 
 Rendezvous* NewLocalRendezvous() { return new LocalRendezvousImpl(); }
 
diff --git a/tensorflow/core/framework/rendezvous_test.cc b/tensorflow/core/framework/rendezvous_test.cc
index da9a1fbbe89..d02d090f32b 100644
--- a/tensorflow/core/framework/rendezvous_test.cc
+++ b/tensorflow/core/framework/rendezvous_test.cc
@@ -107,13 +107,13 @@ Rendezvous::ParsedKey MakeKey(const string& name) {
 }
 
 const Rendezvous::ParsedKey& KeyFoo() {
-  static auto key = MakeKey("foo");
-  return key;
+  static auto* key = new Rendezvous::ParsedKey(MakeKey("foo"));
+  return *key;
 }
 
 const Rendezvous::ParsedKey& KeyBar() {
-  static auto key = MakeKey("bar");
-  return key;
+  static auto* key = new Rendezvous::ParsedKey(MakeKey("bar"));
+  return *key;
 }
 
 TEST_F(LocalRendezvousTest, SendRecv) {
@@ -451,6 +451,32 @@ void BM_SendRecv(int iters) {
 }
 BENCHMARK(BM_SendRecv);
 
+void BM_RecvSend(int iters) {
+  Rendezvous* rendez = NewLocalRendezvous();
+  Tensor orig = V("val");
+  Tensor val(DT_STRING, TensorShape({}));
+  bool is_dead = false;
+  Rendezvous::Args args;
+  if (iters > 0) {
+    while (iters--) {
+      bool received = false;
+      rendez->RecvAsync(
+          KeyFoo(), args,
+          [&val, &received](const Status& s, const Rendezvous::Args& send_args,
+                            const Rendezvous::Args& recv_args,
+                            const Tensor& tensor, bool is_dead) {
+            val = tensor;
+            received = true;
+          });
+      TF_CHECK_OK(rendez->Send(KeyFoo(), args, orig, is_dead));
+      CHECK(received);
+    }
+    CHECK_EQ(V(val), V(orig));
+  }
+  rendez->Unref();
+}
+BENCHMARK(BM_RecvSend);
+
 void BM_PingPong(int iters) {
   CHECK_GT(iters, 0);
   auto* cm = new CancellationManager();