Optimize critical section of LocalRendezvousImpl::{Send,RecvAsync}().

This change makes the following modifications to the internal data structures in order to reduce the time spent in the critical section:

1. Replace std::deque with a linked list of items.

Almost all keys [0] in a Rendezvous table map to a single `Item*`. The current std::deque implementation allocates a 4KiB block when we push that item onto the queue, and must free it again when the item is removed. Using a simple linked list halves the number of allocations when enqueuing an item.
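
For illustration, here is a minimal sketch of the linked-list approach (the full `ItemQueue`, with its DCHECKs, appears in the diff below; `Item`'s payload fields are elided):

    struct Item {
      // ... payload fields elided; see the diff below ...
      Item* next = nullptr;  // Intrusive link: enqueuing needs no extra node.
    };

    struct ItemQueue {
      void push_back(Item* item) {
        if (head == nullptr) {
          // The queue is empty: the new item is both head and tail.
          head = item;
          tail = item;
        } else {
          tail->next = item;
          tail = item;
        }
      }
      Item* head = nullptr;
      Item* tail = nullptr;
    };

Enqueuing therefore costs only the `Item` allocation itself, rather than the `Item` plus a std::deque backing block.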

2. Merge the recv-specific and send-specific members into a union of structs.

We currently waste time running the default constructor for Tensor (when enqueuing a receive waiter), and space storing separate recv_args and send_args, even though only one is ever used. This change compresses the Item structure from 128 bytes to 80 bytes using a union, and uses ManualConstructor for the union fields to avoid non-trivial default construction costs.
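
The key enabler is `gtl::ManualConstructor<T>`, which reserves suitably aligned storage for a `T` inside the union but runs no constructor or destructor until explicitly asked. A minimal usage sketch:

    gtl::ManualConstructor<Tensor> value;  // No Tensor constructor runs here.

    value.Init(t);     // Placement-constructs a Tensor (copy of a Tensor `t`).
    value->shape();    // operator-> / operator* access the live object.
    value.Destroy();   // Explicitly runs ~Tensor().

Since only one union member is ever active, `~Item()` must call `Destroy()` on exactly the member selected by `type`, as the diff below does.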

3. Replace Item::IsSendValue() with an enum equality test.

Previously, we would compare a `std::function` against `nullptr` (invoking `std::function::operator bool()`), which accounted for ~1% of the execution time.
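
In outline, the hot-path check changes from a `std::function` null test to a trivial integer comparison:

    // Before: IsSendValue() compares a std::function against nullptr.
    bool IsSendValue() const { return this->waiter == nullptr; }

    // After: branch on a plain enum tag stored in the Item.
    if (queue->head->type == Item::kSend) { /* ... */ }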

4. Change VLOG to DVLOG in LocalRendezvousImpl.

This matches recent changes to other performance-critical code.
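
This also matters for performance: `DVLOG` compiles away entirely in optimized builds, so stream operands such as `key.FullKey()` are never evaluated on the hot path. A sketch of the usual glog-style definition (assuming the macros in tensorflow/core/platform/logging.h):

    #if defined(NDEBUG)
    // Opt builds: the loop body is dead code, so the compiler removes the
    // whole statement, including evaluation of the streamed operands.
    #define DVLOG(verbose_level) \
      while (false && (verbose_level) > 0) ::tensorflow::internal::LogMessageNull()
    #else
    // Debug builds: identical to VLOG.
    #define DVLOG VLOG
    #endif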

[0] In particular, all Send/Recv ops generated by graph partitioning have this property.

PiperOrigin-RevId: 279963713
Change-Id: If00949ff8ec415c1f4cedb02057413f5e6ae5b7e
Author: Derek Murray, 2019-11-12 07:32:45 -08:00 (committed by TensorFlower Gardener)
Parent: ae42c8f64c
Commit: d8749eca03
2 changed files with 167 additions and 78 deletions

tensorflow/core/framework/rendezvous.cc

@@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/lib/gtl/manual_constructor.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/logging.h"
@@ -146,6 +147,7 @@ Status Rendezvous::Recv(const ParsedKey& key, const Args& args, Tensor* val,
return Recv(key, args, val, is_dead, no_timeout);
}
namespace {
class LocalRendezvousImpl : public Rendezvous {
public:
explicit LocalRendezvousImpl() {}
@@ -153,7 +155,7 @@ class LocalRendezvousImpl : public Rendezvous {
Status Send(const ParsedKey& key, const Args& send_args, const Tensor& val,
const bool is_dead) override {
uint64 key_hash = KeyHash(key.FullKey());
VLOG(2) << "Send " << this << " " << key_hash << " " << key.FullKey();
DVLOG(2) << "Send " << this << " " << key_hash << " " << key.FullKey();
mu_.lock();
if (!status_.ok()) {
@@ -164,40 +166,36 @@ class LocalRendezvousImpl : public Rendezvous {
}
ItemQueue* queue = &table_[key_hash];
if (queue->empty() || queue->front()->IsSendValue()) {
if (queue->head == nullptr || queue->head->type == Item::kSend) {
// There is no waiter for this message. Append the message
// into the queue. The waiter will pick it up when it arrives.
// Only send-related fields need to be filled.
VLOG(2) << "Enqueue Send Item (key:" << key.FullKey() << "). ";
Item* item = new Item;
item->value = val;
item->is_dead = is_dead;
item->send_args = send_args;
if (item->send_args.device_context) {
item->send_args.device_context->Ref();
}
queue->push_back(item);
// TODO(b/143786186): Investigate moving the allocation of `Item` outside
// the lock.
DVLOG(2) << "Enqueue Send Item (key:" << key.FullKey() << "). ";
queue->push_back(new Item(send_args, val, is_dead));
mu_.unlock();
return Status::OK();
}
VLOG(2) << "Consume Recv Item (key:" << key.FullKey() << "). ";
DVLOG(2) << "Consume Recv Item (key:" << key.FullKey() << "). ";
// There is an earliest waiter to consume this message.
Item* item = queue->front();
Item* item = queue->head;
// Delete the queue when the last element has been consumed.
if (queue->size() == 1) {
VLOG(2) << "Clean up Send/Recv queue (key:" << key.FullKey() << "). ";
if (item->next == nullptr) {
DVLOG(2) << "Clean up Send/Recv queue (key:" << key.FullKey() << "). ";
table_.erase(key_hash);
} else {
queue->pop_front();
queue->head = item->next;
}
mu_.unlock();
// Notify the waiter by invoking its done closure, outside the
// lock.
DCHECK(!item->IsSendValue());
item->waiter(Status::OK(), send_args, item->recv_args, val, is_dead);
DCHECK_EQ(item->type, Item::kRecv);
(*item->recv_state.waiter)(Status::OK(), send_args, item->args, val,
is_dead);
delete item;
return Status::OK();
}
@@ -205,7 +203,7 @@ class LocalRendezvousImpl : public Rendezvous {
void RecvAsync(const ParsedKey& key, const Args& recv_args,
DoneCallback done) override {
uint64 key_hash = KeyHash(key.FullKey());
VLOG(2) << "Recv " << this << " " << key_hash << " " << key.FullKey();
DVLOG(2) << "Recv " << this << " " << key_hash << " " << key.FullKey();
mu_.lock();
if (!status_.ok()) {
@@ -217,7 +215,7 @@ class LocalRendezvousImpl : public Rendezvous {
}
ItemQueue* queue = &table_[key_hash];
if (queue->empty() || !queue->front()->IsSendValue()) {
if (queue->head == nullptr || queue->head->type == Item::kRecv) {
// There is no message to pick up.
// Only recv-related fields need to be filled.
CancellationManager* cm = recv_args.cancellation_manager;
@@ -231,14 +229,29 @@
{
mutex_lock l(mu_);
ItemQueue* queue = &table_[key_hash];
if (!queue->empty() && !queue->front()->IsSendValue()) {
for (auto it = queue->begin(); it != queue->end(); it++) {
if ((*it)->cancellation_token == token) {
item = *it;
if (queue->size() == 1) {
// Find an item in the queue with a cancellation token that matches
// `token`, and remove it.
if (queue->head != nullptr && queue->head->type == Item::kRecv) {
for (Item *prev = nullptr, *curr = queue->head; curr != nullptr;
prev = curr, curr = curr->next) {
if (curr->recv_state.cancellation_token == token) {
item = curr;
if (queue->head->next == nullptr) {
// We have a single-element queue, so we can erase it from
// the table.
table_.erase(key_hash);
} else {
queue->erase(it);
// Remove the current item from the queue.
if (curr == queue->head) {
DCHECK_EQ(prev, nullptr);
queue->head = curr->next;
} else {
DCHECK_NE(prev, nullptr);
prev->next = curr->next;
}
if (queue->tail == curr) {
queue->tail = prev;
}
}
break;
}
@@ -247,9 +260,10 @@ class LocalRendezvousImpl : public Rendezvous {
}
if (item != nullptr) {
item->waiter(StatusGroup::MakeDerived(
errors::Cancelled("RecvAsync is cancelled.")),
Args(), item->recv_args, Tensor(), /*is_dead=*/false);
(*item->recv_state.waiter)(
StatusGroup::MakeDerived(
errors::Cancelled("RecvAsync is cancelled.")),
Args(), item->args, Tensor(), /*is_dead=*/false);
delete item;
}
});
@@ -262,51 +276,49 @@ class LocalRendezvousImpl : public Rendezvous {
return;
}
VLOG(2) << "Enqueue Recv Item (key:" << key.FullKey() << "). ";
Item* item = new Item;
DVLOG(2) << "Enqueue Recv Item (key:" << key.FullKey() << "). ";
// TODO(b/143786186): Investigate moving the allocation of `Item` outside
// the lock.
if (cm != nullptr) {
// NOTE(mrry): We must wrap `done` with code that deregisters the
// cancellation callback before calling the `done` callback, because the
// cancellation manager may no longer be live after `done` is called.
item->waiter = [cm, token, done = std::move(done)](
const Status& s, const Args& send_args,
const Args& recv_args, const Tensor& v, bool dead) {
cm->TryDeregisterCallback(token);
done(s, send_args, recv_args, v, dead);
};
queue->push_back(new Item(
recv_args,
[cm, token, done = std::move(done)](
const Status& s, const Args& send_args, const Args& recv_args,
const Tensor& v, bool dead) {
cm->TryDeregisterCallback(token);
done(s, send_args, recv_args, v, dead);
},
token));
} else {
item->waiter = std::move(done);
queue->push_back(new Item(recv_args, std::move(done), token));
}
item->recv_args = recv_args;
item->cancellation_token = token;
if (item->recv_args.device_context) {
item->recv_args.device_context->Ref();
}
queue->push_back(item);
mu_.unlock();
return;
}
VLOG(2) << "Consume Send Item (key:" << key.FullKey() << "). ";
DVLOG(2) << "Consume Send Item (key:" << key.FullKey() << "). ";
// A message has already arrived and is queued in the table under
// this key. Consumes the message and invokes the done closure.
Item* item = queue->front();
Item* item = queue->head;
// Delete the queue when the last element has been consumed.
if (queue->size() == 1) {
VLOG(2) << "Clean up Send/Recv queue (key:" << key.FullKey() << "). ";
if (item->next == nullptr) {
DVLOG(2) << "Clean up Send/Recv queue (key:" << key.FullKey() << "). ";
table_.erase(key_hash);
} else {
queue->pop_front();
queue->head = item->next;
}
mu_.unlock();
// Invokes the done() by invoking its done closure, outside scope
// of the table lock.
DCHECK(item->IsSendValue());
done(Status::OK(), item->send_args, recv_args, item->value, item->is_dead);
// Invoke done() without holding the table lock.
DCHECK_EQ(item->type, Item::kSend);
done(Status::OK(), item->args, recv_args, *item->send_state.value,
item->send_state.is_dead);
delete item;
}
@@ -319,11 +331,14 @@ class LocalRendezvousImpl : public Rendezvous {
table_.swap(table);
}
for (auto& p : table) {
for (Item* item : p.second) {
if (!item->IsSendValue()) {
item->waiter(status, Args(), Args(), Tensor(), false);
Item* item = p.second.head;
while (item != nullptr) {
if (item->type == Item::kRecv) {
(*item->recv_state.waiter)(status, Args(), Args(), Tensor(), false);
}
delete item;
Item* to_delete = item;
item = item->next;
delete to_delete;
}
}
}
@@ -331,25 +346,59 @@ class LocalRendezvousImpl : public Rendezvous {
private:
typedef LocalRendezvousImpl ME;
// Represents a blocked Send() or Recv() call in the rendezvous.
struct Item {
DoneCallback waiter = nullptr;
Tensor value;
bool is_dead = false;
Args send_args;
Args recv_args;
CancellationToken cancellation_token;
enum Type { kSend = 0, kRecv = 1 };
Item(Args send_args, const Tensor& value, bool is_dead)
: Item(send_args, kSend) {
send_state.value.Init(value);
send_state.is_dead = is_dead;
}
Item(Args recv_args, DoneCallback waiter,
CancellationToken cancellation_token)
: Item(recv_args, kRecv) {
recv_state.waiter.Init(std::move(waiter));
recv_state.cancellation_token = cancellation_token;
}
~Item() {
if (send_args.device_context) {
send_args.device_context->Unref();
if (args.device_context) {
args.device_context->Unref();
}
if (recv_args.device_context) {
recv_args.device_context->Unref();
if (type == kSend) {
send_state.value.Destroy();
} else {
recv_state.waiter.Destroy();
}
}
// Returns true iff this item represents a value being sent.
bool IsSendValue() const { return this->waiter == nullptr; }
const Args args;
const Type type;
// Link to next item in an ItemQueue.
Item* next = nullptr;
// The validity of `send_state` or `recv_state` is determined by `type ==
// kSend` or `type == kRecv` respectively.
union {
struct {
ManualConstructor<Tensor> value;
bool is_dead;
} send_state;
struct {
ManualConstructor<DoneCallback> waiter;
CancellationToken cancellation_token;
} recv_state;
};
private:
Item(Args args, Type type) : args(args), type(type) {
if (args.device_context) {
args.device_context->Ref();
}
}
};
// We key the hash table by KeyHash of the Rendezvous::CreateKey string
@@ -358,12 +407,25 @@ class LocalRendezvousImpl : public Rendezvous {
}
// By invariant, the item queue under each key is of the form
// [item.IsSendValue()]* meaning each item is a sent message.
// [item.type == kSend]* meaning each item is a sent message.
// or
// [!item.IsSendValue()]* meaning each item is a waiter.
//
// TODO(zhifengc): consider a better queue impl than std::deque.
typedef std::deque<Item*> ItemQueue;
// [item.type == kRecv]* meaning each item is a waiter.
struct ItemQueue {
void push_back(Item* item) {
if (TF_PREDICT_TRUE(head == nullptr)) {
// The queue is empty.
head = item;
tail = item;
} else {
DCHECK_EQ(tail->type, item->type);
tail->next = item;
tail = item;
}
}
Item* head = nullptr;
Item* tail = nullptr;
};
typedef gtl::FlatMap<uint64, ItemQueue> Table;
// TODO(zhifengc): shard table_.
@@ -379,6 +441,7 @@ class LocalRendezvousImpl : public Rendezvous {
TF_DISALLOW_COPY_AND_ASSIGN(LocalRendezvousImpl);
};
} // namespace
Rendezvous* NewLocalRendezvous() { return new LocalRendezvousImpl(); }

tensorflow/core/framework/rendezvous_test.cc

@@ -107,13 +107,13 @@ Rendezvous::ParsedKey MakeKey(const string& name) {
}
const Rendezvous::ParsedKey& KeyFoo() {
static auto key = MakeKey("foo");
return key;
static auto* key = new Rendezvous::ParsedKey(MakeKey("foo"));
return *key;
}
const Rendezvous::ParsedKey& KeyBar() {
static auto key = MakeKey("bar");
return key;
static auto* key = new Rendezvous::ParsedKey(MakeKey("bar"));
return *key;
}
TEST_F(LocalRendezvousTest, SendRecv) {
@@ -451,6 +451,32 @@ void BM_SendRecv(int iters) {
}
BENCHMARK(BM_SendRecv);
void BM_RecvSend(int iters) {
Rendezvous* rendez = NewLocalRendezvous();
Tensor orig = V("val");
Tensor val(DT_STRING, TensorShape({}));
bool is_dead = false;
Rendezvous::Args args;
if (iters > 0) {
while (iters--) {
bool received = false;
rendez->RecvAsync(
KeyFoo(), args,
[&val, &received](const Status& s, const Rendezvous::Args& send_args,
const Rendezvous::Args& recv_args,
const Tensor& tensor, bool is_dead) {
val = tensor;
received = true;
});
TF_CHECK_OK(rendez->Send(KeyFoo(), args, orig, is_dead));
CHECK(received);
}
CHECK_EQ(V(val), V(orig));
}
rendez->Unref();
}
BENCHMARK(BM_RecvSend);
void BM_PingPong(int iters) {
CHECK_GT(iters, 0);
auto* cm = new CancellationManager();