Optimize critical section of LocalRendezvousImpl::{Send,RecvAsync}().
This change makes the following changes to the internal data structure in order to reduce the time spent in the critical section: 1. Replace std::deque with a linked list of items. Almost all keys [0] in a Rendezvous table map to a single `Item*`. The current std::deque implementation allocates a 4KiB block when we push that item onto the queue, and must free it again when the item is removed. Using a simple linked list halves the number of allocations when enqueuing an item. 2. Merge the recv-specific and send-specific members into a union of structs. We currently waste time running the default constructor for Tensor (when enqueuing a receive waiter), and space storing separate recv_args and send_args, even though only one is ever used. This change compresses the Item structure from 128 bytes to 80 bytes using a union, and uses ManualConstructor for the union fields to avoid non-trivial default construction costs. 3. Replace Item::IsSendValue() with an enum equality test. Previously, we would invoke std::function::operator=, which accounted for ~1% of the execution time. 4. Change VLOG to DVLOG in LocalRendezvousImpl. This matches recent changes to other performance-critical code. [0] In particular, all Send/Recv ops generated by graph partitioning have this property. PiperOrigin-RevId: 279963713 Change-Id: If00949ff8ec415c1f4cedb02057413f5e6ae5b7e
This commit is contained in:
parent
ae42c8f64c
commit
d8749eca03
@ -23,6 +23,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/core/notification.h"
|
||||
#include "tensorflow/core/lib/gtl/flatmap.h"
|
||||
#include "tensorflow/core/lib/gtl/manual_constructor.h"
|
||||
#include "tensorflow/core/lib/hash/hash.h"
|
||||
#include "tensorflow/core/lib/strings/str_util.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
@ -146,6 +147,7 @@ Status Rendezvous::Recv(const ParsedKey& key, const Args& args, Tensor* val,
|
||||
return Recv(key, args, val, is_dead, no_timeout);
|
||||
}
|
||||
|
||||
namespace {
|
||||
class LocalRendezvousImpl : public Rendezvous {
|
||||
public:
|
||||
explicit LocalRendezvousImpl() {}
|
||||
@ -153,7 +155,7 @@ class LocalRendezvousImpl : public Rendezvous {
|
||||
Status Send(const ParsedKey& key, const Args& send_args, const Tensor& val,
|
||||
const bool is_dead) override {
|
||||
uint64 key_hash = KeyHash(key.FullKey());
|
||||
VLOG(2) << "Send " << this << " " << key_hash << " " << key.FullKey();
|
||||
DVLOG(2) << "Send " << this << " " << key_hash << " " << key.FullKey();
|
||||
|
||||
mu_.lock();
|
||||
if (!status_.ok()) {
|
||||
@ -164,40 +166,36 @@ class LocalRendezvousImpl : public Rendezvous {
|
||||
}
|
||||
|
||||
ItemQueue* queue = &table_[key_hash];
|
||||
if (queue->empty() || queue->front()->IsSendValue()) {
|
||||
if (queue->head == nullptr || queue->head->type == Item::kSend) {
|
||||
// There is no waiter for this message. Append the message
|
||||
// into the queue. The waiter will pick it up when arrives.
|
||||
// Only send-related fields need to be filled.
|
||||
VLOG(2) << "Enqueue Send Item (key:" << key.FullKey() << "). ";
|
||||
Item* item = new Item;
|
||||
item->value = val;
|
||||
item->is_dead = is_dead;
|
||||
item->send_args = send_args;
|
||||
if (item->send_args.device_context) {
|
||||
item->send_args.device_context->Ref();
|
||||
}
|
||||
queue->push_back(item);
|
||||
// TODO(b/143786186): Investigate moving the allocation of `Item` outside
|
||||
// the lock.
|
||||
DVLOG(2) << "Enqueue Send Item (key:" << key.FullKey() << "). ";
|
||||
queue->push_back(new Item(send_args, val, is_dead));
|
||||
mu_.unlock();
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
VLOG(2) << "Consume Recv Item (key:" << key.FullKey() << "). ";
|
||||
DVLOG(2) << "Consume Recv Item (key:" << key.FullKey() << "). ";
|
||||
// There is an earliest waiter to consume this message.
|
||||
Item* item = queue->front();
|
||||
Item* item = queue->head;
|
||||
|
||||
// Delete the queue when the last element has been consumed.
|
||||
if (queue->size() == 1) {
|
||||
VLOG(2) << "Clean up Send/Recv queue (key:" << key.FullKey() << "). ";
|
||||
if (item->next == nullptr) {
|
||||
DVLOG(2) << "Clean up Send/Recv queue (key:" << key.FullKey() << "). ";
|
||||
table_.erase(key_hash);
|
||||
} else {
|
||||
queue->pop_front();
|
||||
queue->head = item->next;
|
||||
}
|
||||
mu_.unlock();
|
||||
|
||||
// Notify the waiter by invoking its done closure, outside the
|
||||
// lock.
|
||||
DCHECK(!item->IsSendValue());
|
||||
item->waiter(Status::OK(), send_args, item->recv_args, val, is_dead);
|
||||
DCHECK_EQ(item->type, Item::kRecv);
|
||||
(*item->recv_state.waiter)(Status::OK(), send_args, item->args, val,
|
||||
is_dead);
|
||||
delete item;
|
||||
return Status::OK();
|
||||
}
|
||||
@ -205,7 +203,7 @@ class LocalRendezvousImpl : public Rendezvous {
|
||||
void RecvAsync(const ParsedKey& key, const Args& recv_args,
|
||||
DoneCallback done) override {
|
||||
uint64 key_hash = KeyHash(key.FullKey());
|
||||
VLOG(2) << "Recv " << this << " " << key_hash << " " << key.FullKey();
|
||||
DVLOG(2) << "Recv " << this << " " << key_hash << " " << key.FullKey();
|
||||
|
||||
mu_.lock();
|
||||
if (!status_.ok()) {
|
||||
@ -217,7 +215,7 @@ class LocalRendezvousImpl : public Rendezvous {
|
||||
}
|
||||
|
||||
ItemQueue* queue = &table_[key_hash];
|
||||
if (queue->empty() || !queue->front()->IsSendValue()) {
|
||||
if (queue->head == nullptr || queue->head->type == Item::kRecv) {
|
||||
// There is no message to pick up.
|
||||
// Only recv-related fields need to be filled.
|
||||
CancellationManager* cm = recv_args.cancellation_manager;
|
||||
@ -231,14 +229,29 @@ class LocalRendezvousImpl : public Rendezvous {
|
||||
{
|
||||
mutex_lock l(mu_);
|
||||
ItemQueue* queue = &table_[key_hash];
|
||||
if (!queue->empty() && !queue->front()->IsSendValue()) {
|
||||
for (auto it = queue->begin(); it != queue->end(); it++) {
|
||||
if ((*it)->cancellation_token == token) {
|
||||
item = *it;
|
||||
if (queue->size() == 1) {
|
||||
// Find an item in the queue with a cancellation token that matches
|
||||
// `token`, and remove it.
|
||||
if (queue->head != nullptr && queue->head->type == Item::kRecv) {
|
||||
for (Item *prev = nullptr, *curr = queue->head; curr != nullptr;
|
||||
prev = curr, curr = curr->next) {
|
||||
if (curr->recv_state.cancellation_token == token) {
|
||||
item = curr;
|
||||
if (queue->head->next == nullptr) {
|
||||
// We have a single-element queue, so we can erase it from
|
||||
// the table.
|
||||
table_.erase(key_hash);
|
||||
} else {
|
||||
queue->erase(it);
|
||||
// Remove the current item from the queue.
|
||||
if (curr == queue->head) {
|
||||
DCHECK_EQ(prev, nullptr);
|
||||
queue->head = curr->next;
|
||||
} else {
|
||||
DCHECK_NE(prev, nullptr);
|
||||
prev->next = curr->next;
|
||||
}
|
||||
if (queue->tail == curr) {
|
||||
queue->tail = prev;
|
||||
}
|
||||
}
|
||||
break;
|
||||
}
|
||||
@ -247,9 +260,10 @@ class LocalRendezvousImpl : public Rendezvous {
|
||||
}
|
||||
|
||||
if (item != nullptr) {
|
||||
item->waiter(StatusGroup::MakeDerived(
|
||||
errors::Cancelled("RecvAsync is cancelled.")),
|
||||
Args(), item->recv_args, Tensor(), /*is_dead=*/false);
|
||||
(*item->recv_state.waiter)(
|
||||
StatusGroup::MakeDerived(
|
||||
errors::Cancelled("RecvAsync is cancelled.")),
|
||||
Args(), item->args, Tensor(), /*is_dead=*/false);
|
||||
delete item;
|
||||
}
|
||||
});
|
||||
@ -262,51 +276,49 @@ class LocalRendezvousImpl : public Rendezvous {
|
||||
return;
|
||||
}
|
||||
|
||||
VLOG(2) << "Enqueue Recv Item (key:" << key.FullKey() << "). ";
|
||||
Item* item = new Item;
|
||||
DVLOG(2) << "Enqueue Recv Item (key:" << key.FullKey() << "). ";
|
||||
|
||||
// TODO(b/143786186): Investigate moving the allocation of `Item` outside
|
||||
// the lock.
|
||||
if (cm != nullptr) {
|
||||
// NOTE(mrry): We must wrap `done` with code that deregisters the
|
||||
// cancellation callback before calling the `done` callback, because the
|
||||
// cancellation manager may no longer be live after `done` is called.
|
||||
item->waiter = [cm, token, done = std::move(done)](
|
||||
const Status& s, const Args& send_args,
|
||||
const Args& recv_args, const Tensor& v, bool dead) {
|
||||
cm->TryDeregisterCallback(token);
|
||||
done(s, send_args, recv_args, v, dead);
|
||||
};
|
||||
queue->push_back(new Item(
|
||||
recv_args,
|
||||
[cm, token, done = std::move(done)](
|
||||
const Status& s, const Args& send_args, const Args& recv_args,
|
||||
const Tensor& v, bool dead) {
|
||||
cm->TryDeregisterCallback(token);
|
||||
done(s, send_args, recv_args, v, dead);
|
||||
},
|
||||
token));
|
||||
} else {
|
||||
item->waiter = std::move(done);
|
||||
queue->push_back(new Item(recv_args, std::move(done), token));
|
||||
}
|
||||
|
||||
item->recv_args = recv_args;
|
||||
item->cancellation_token = token;
|
||||
if (item->recv_args.device_context) {
|
||||
item->recv_args.device_context->Ref();
|
||||
}
|
||||
queue->push_back(item);
|
||||
mu_.unlock();
|
||||
return;
|
||||
}
|
||||
|
||||
VLOG(2) << "Consume Send Item (key:" << key.FullKey() << "). ";
|
||||
DVLOG(2) << "Consume Send Item (key:" << key.FullKey() << "). ";
|
||||
// A message has already arrived and is queued in the table under
|
||||
// this key. Consumes the message and invokes the done closure.
|
||||
Item* item = queue->front();
|
||||
Item* item = queue->head;
|
||||
|
||||
// Delete the queue when the last element has been consumed.
|
||||
if (queue->size() == 1) {
|
||||
VLOG(2) << "Clean up Send/Recv queue (key:" << key.FullKey() << "). ";
|
||||
if (item->next == nullptr) {
|
||||
DVLOG(2) << "Clean up Send/Recv queue (key:" << key.FullKey() << "). ";
|
||||
table_.erase(key_hash);
|
||||
} else {
|
||||
queue->pop_front();
|
||||
queue->head = item->next;
|
||||
}
|
||||
mu_.unlock();
|
||||
|
||||
// Invokes the done() by invoking its done closure, outside scope
|
||||
// of the table lock.
|
||||
DCHECK(item->IsSendValue());
|
||||
done(Status::OK(), item->send_args, recv_args, item->value, item->is_dead);
|
||||
// Invoke done() without holding the table lock.
|
||||
DCHECK_EQ(item->type, Item::kSend);
|
||||
done(Status::OK(), item->args, recv_args, *item->send_state.value,
|
||||
item->send_state.is_dead);
|
||||
delete item;
|
||||
}
|
||||
|
||||
@ -319,11 +331,14 @@ class LocalRendezvousImpl : public Rendezvous {
|
||||
table_.swap(table);
|
||||
}
|
||||
for (auto& p : table) {
|
||||
for (Item* item : p.second) {
|
||||
if (!item->IsSendValue()) {
|
||||
item->waiter(status, Args(), Args(), Tensor(), false);
|
||||
Item* item = p.second.head;
|
||||
while (item != nullptr) {
|
||||
if (item->type == Item::kRecv) {
|
||||
(*item->recv_state.waiter)(status, Args(), Args(), Tensor(), false);
|
||||
}
|
||||
delete item;
|
||||
Item* to_delete = item;
|
||||
item = item->next;
|
||||
delete to_delete;
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -331,25 +346,59 @@ class LocalRendezvousImpl : public Rendezvous {
|
||||
private:
|
||||
typedef LocalRendezvousImpl ME;
|
||||
|
||||
// Represents a blocked Send() or Recv() call in the rendezvous.
|
||||
struct Item {
|
||||
DoneCallback waiter = nullptr;
|
||||
Tensor value;
|
||||
bool is_dead = false;
|
||||
Args send_args;
|
||||
Args recv_args;
|
||||
CancellationToken cancellation_token;
|
||||
enum Type { kSend = 0, kRecv = 1 };
|
||||
|
||||
Item(Args send_args, const Tensor& value, bool is_dead)
|
||||
: Item(send_args, kSend) {
|
||||
send_state.value.Init(value);
|
||||
send_state.is_dead = is_dead;
|
||||
}
|
||||
|
||||
Item(Args recv_args, DoneCallback waiter,
|
||||
CancellationToken cancellation_token)
|
||||
: Item(recv_args, kRecv) {
|
||||
recv_state.waiter.Init(std::move(waiter));
|
||||
recv_state.cancellation_token = cancellation_token;
|
||||
}
|
||||
|
||||
~Item() {
|
||||
if (send_args.device_context) {
|
||||
send_args.device_context->Unref();
|
||||
if (args.device_context) {
|
||||
args.device_context->Unref();
|
||||
}
|
||||
if (recv_args.device_context) {
|
||||
recv_args.device_context->Unref();
|
||||
if (type == kSend) {
|
||||
send_state.value.Destroy();
|
||||
} else {
|
||||
recv_state.waiter.Destroy();
|
||||
}
|
||||
}
|
||||
|
||||
// Returns true iff this item represents a value being sent.
|
||||
bool IsSendValue() const { return this->waiter == nullptr; }
|
||||
const Args args;
|
||||
const Type type;
|
||||
|
||||
// Link to next item in an ItemQueue.
|
||||
Item* next = nullptr;
|
||||
|
||||
// The validity of `send_state` or `recv_state` is determined by `type ==
|
||||
// kSend` or `type == kRecv` respectively.
|
||||
union {
|
||||
struct {
|
||||
ManualConstructor<Tensor> value;
|
||||
bool is_dead;
|
||||
} send_state;
|
||||
struct {
|
||||
ManualConstructor<DoneCallback> waiter;
|
||||
CancellationToken cancellation_token;
|
||||
} recv_state;
|
||||
};
|
||||
|
||||
private:
|
||||
Item(Args args, Type type) : args(args), type(type) {
|
||||
if (args.device_context) {
|
||||
args.device_context->Ref();
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// We key the hash table by KeyHash of the Rendezvous::CreateKey string
|
||||
@ -358,12 +407,25 @@ class LocalRendezvousImpl : public Rendezvous {
|
||||
}
|
||||
|
||||
// By invariant, the item queue under each key is of the form
|
||||
// [item.IsSendValue()]* meaning each item is a sent message.
|
||||
// [item.type == kSend]* meaning each item is a sent message.
|
||||
// or
|
||||
// [!item.IsSendValue()]* meaning each item is a waiter.
|
||||
//
|
||||
// TODO(zhifengc): consider a better queue impl than std::deque.
|
||||
typedef std::deque<Item*> ItemQueue;
|
||||
// [item.type == kRecv]* meaning each item is a waiter.
|
||||
struct ItemQueue {
|
||||
void push_back(Item* item) {
|
||||
if (TF_PREDICT_TRUE(head == nullptr)) {
|
||||
// The queue is empty.
|
||||
head = item;
|
||||
tail = item;
|
||||
} else {
|
||||
DCHECK_EQ(tail->type, item->type);
|
||||
tail->next = item;
|
||||
tail = item;
|
||||
}
|
||||
}
|
||||
|
||||
Item* head = nullptr;
|
||||
Item* tail = nullptr;
|
||||
};
|
||||
typedef gtl::FlatMap<uint64, ItemQueue> Table;
|
||||
|
||||
// TODO(zhifengc): shard table_.
|
||||
@ -379,6 +441,7 @@ class LocalRendezvousImpl : public Rendezvous {
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(LocalRendezvousImpl);
|
||||
};
|
||||
} // namespace
|
||||
|
||||
Rendezvous* NewLocalRendezvous() { return new LocalRendezvousImpl(); }
|
||||
|
||||
|
@ -107,13 +107,13 @@ Rendezvous::ParsedKey MakeKey(const string& name) {
|
||||
}
|
||||
|
||||
const Rendezvous::ParsedKey& KeyFoo() {
|
||||
static auto key = MakeKey("foo");
|
||||
return key;
|
||||
static auto* key = new Rendezvous::ParsedKey(MakeKey("foo"));
|
||||
return *key;
|
||||
}
|
||||
|
||||
const Rendezvous::ParsedKey& KeyBar() {
|
||||
static auto key = MakeKey("bar");
|
||||
return key;
|
||||
static auto* key = new Rendezvous::ParsedKey(MakeKey("bar"));
|
||||
return *key;
|
||||
}
|
||||
|
||||
TEST_F(LocalRendezvousTest, SendRecv) {
|
||||
@ -451,6 +451,32 @@ void BM_SendRecv(int iters) {
|
||||
}
|
||||
BENCHMARK(BM_SendRecv);
|
||||
|
||||
void BM_RecvSend(int iters) {
|
||||
Rendezvous* rendez = NewLocalRendezvous();
|
||||
Tensor orig = V("val");
|
||||
Tensor val(DT_STRING, TensorShape({}));
|
||||
bool is_dead = false;
|
||||
Rendezvous::Args args;
|
||||
if (iters > 0) {
|
||||
while (iters--) {
|
||||
bool received = false;
|
||||
rendez->RecvAsync(
|
||||
KeyFoo(), args,
|
||||
[&val, &received](const Status& s, const Rendezvous::Args& send_args,
|
||||
const Rendezvous::Args& recv_args,
|
||||
const Tensor& tensor, bool is_dead) {
|
||||
val = tensor;
|
||||
received = true;
|
||||
});
|
||||
TF_CHECK_OK(rendez->Send(KeyFoo(), args, orig, is_dead));
|
||||
CHECK(received);
|
||||
}
|
||||
CHECK_EQ(V(val), V(orig));
|
||||
}
|
||||
rendez->Unref();
|
||||
}
|
||||
BENCHMARK(BM_RecvSend);
|
||||
|
||||
void BM_PingPong(int iters) {
|
||||
CHECK_GT(iters, 0);
|
||||
auto* cm = new CancellationManager();
|
||||
|
Loading…
Reference in New Issue
Block a user