Expose the definition of LocalRendezvous to its users.
Previously, creating a LocalRendezvous required a dynamic allocation and refcount manipulation, and invoking it required virtual method calls. This change allows users (especially IntraProcessRendezvous) to instantiate a LocalRendezvous directly as a member, which avoids these overheads.

PiperOrigin-RevId: 282802675
Change-Id: I81d91a2ddd2365bb5612c402db0e1bf66a9ea5f4
parent 8053a43598
commit 235f0e2a89
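The motivation is easiest to see in miniature. Below is a self-contained sketch of the before/after ownership pattern; the class names are stand-ins invented for illustration, not the actual TensorFlow types:

// Minimal sketch of the pattern this commit enables (stand-in types only).
#include <cstdio>

// Stand-in for the refcounted, virtual Rendezvous interface.
class RefCountedRendezvous {
 public:
  virtual void Send() = 0;  // virtual dispatch on every call
  void Unref() {
    if (--refs_ == 0) delete this;  // refcount bookkeeping on teardown
  }
  virtual ~RefCountedRendezvous() = default;

 private:
  int refs_ = 1;  // starts with one reference
};

class HeapLocalRendezvous : public RefCountedRendezvous {
 public:
  void Send() override { std::puts("send"); }
};

// Before: the owner heap-allocates and must Unref() in its destructor.
class OwnerBefore {
 public:
  OwnerBefore() : local_(new HeapLocalRendezvous) {}
  ~OwnerBefore() { local_->Unref(); }
  void Send() { local_->Send(); }  // goes through the vtable

 private:
  RefCountedRendezvous* local_;
};

// After: a concrete LocalRendezvous-style member; no allocation, no
// refcount, and calls can be resolved (and inlined) statically.
class DirectLocalRendezvous {
 public:
  void Send() { std::puts("send"); }
};

class OwnerAfter {
 public:
  void Send() { local_.Send(); }  // direct call

 private:
  DirectLocalRendezvous local_;
};

The first diff below makes exactly this change to IntraProcessRendezvous: `LocalRendezvous local_;` replaces `Rendezvous* local_;`, and both the `NewLocalRendezvous()` call and the `Unref()` in the destructor disappear.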
@@ -33,16 +33,16 @@ limitations under the License.
 namespace tensorflow {
 
 IntraProcessRendezvous::IntraProcessRendezvous(const DeviceMgr* device_mgr)
-    : device_mgr_(device_mgr), local_(NewLocalRendezvous()) {}
+    : device_mgr_(device_mgr) {}
 
-IntraProcessRendezvous::~IntraProcessRendezvous() { local_->Unref(); }
+IntraProcessRendezvous::~IntraProcessRendezvous() {}
 
-Status IntraProcessRendezvous::Send(const ParsedKey& parsed,
+Status IntraProcessRendezvous::Send(const ParsedKey& key,
                                     const Rendezvous::Args& args,
                                     const Tensor& val, const bool is_dead) {
-  VLOG(1) << "IntraProcessRendezvous Send " << this << " " << parsed.FullKey();
+  VLOG(1) << "IntraProcessRendezvous Send " << this << " " << key.FullKey();
   // Buffers "val" and "device_context" in local_.
-  return local_->Send(parsed, args, val, is_dead);
+  return local_.Send(key, args, val, is_dead);
 }
 
 void IntraProcessRendezvous::SameWorkerRecvDone(
@@ -116,16 +116,16 @@ void IntraProcessRendezvous::SameWorkerRecvDone(
       out, 0 /*dev_to_dev_stream_index*/, std::move(done), sync_dst_compute);
 }
 
-void IntraProcessRendezvous::RecvAsync(const ParsedKey& parsed,
-                                       const Rendezvous::Args& recv_args,
+void IntraProcessRendezvous::RecvAsync(const ParsedKey& key,
+                                       const Rendezvous::Args& args,
                                        DoneCallback done) {
-  VLOG(1) << "IntraProcessRendezvous Recv " << this << " " << parsed.FullKey();
+  VLOG(1) << "IntraProcessRendezvous Recv " << this << " " << key.FullKey();
 
   MEMDEBUG_CACHE_OP("RecvAsync");
   // Recv the tensor from local_.
-  local_->RecvAsync(
-      parsed, recv_args,
-      [this, parsed, done = std::move(done)](
+  local_.RecvAsync(
+      key, args,
+      [this, key, done = std::move(done)](
           const Status& status, const Rendezvous::Args& send_args,
           const Rendezvous::Args& recv_args, const Tensor& in,
           bool is_dead) mutable {
@@ -141,7 +141,7 @@ void IntraProcessRendezvous::RecvAsync(const ParsedKey& parsed,
         };
 
         if (status.ok() && in.IsInitialized()) {
-          SameWorkerRecvDone(parsed, send_args, recv_args, in, out,
+          SameWorkerRecvDone(key, send_args, recv_args, in, out,
                              std::move(final_callback));
        } else {
          final_callback(status);
@@ -151,7 +151,7 @@ void IntraProcessRendezvous::RecvAsync(const ParsedKey& parsed,
 
 void IntraProcessRendezvous::StartAbort(const Status& s) {
   CHECK(!s.ok());
-  local_->StartAbort(s);
+  local_.StartAbort(s);
 }
 
 }  // end namespace tensorflow
@@ -20,6 +20,7 @@ limitations under the License.
 #include <unordered_map>
 
 #include "tensorflow/core/common_runtime/device_mgr.h"
+#include "tensorflow/core/framework/local_rendezvous.h"
 #include "tensorflow/core/framework/rendezvous.h"
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/lib/core/status.h"
@@ -31,12 +32,11 @@ namespace tensorflow {
 
 // IntraProcessRendezvous is a Rendezvous which expects all producers
 // and consumers to be devices immediately accessible within the
-// process. That is, it will never be necessary to perform an RPC to
+// process. That is, it will never be necessary to perform an RPC to
 // communicate with either.
 //
-// Buffering of Tensor values is delegated to a "local" Rendezvous
-// obtained from NewLocalRendezvous(). This class just adds
-// functionality to coordinate multiple process-local devices.
+// Buffering of Tensor values is delegated to a `LocalRendezvous`. This class
+// just adds functionality to coordinate multiple process-local devices.
 class IntraProcessRendezvous : public Rendezvous {
  public:
   explicit IntraProcessRendezvous(const DeviceMgr* device_mgr);
@@ -57,7 +57,7 @@ class IntraProcessRendezvous : public Rendezvous {
 
  private:
   const DeviceMgr* device_mgr_;
-  Rendezvous* local_;  // Owns a Ref on this object.
+  LocalRendezvous local_;
 
   ~IntraProcessRendezvous() override;
 
tensorflow/core/framework/local_rendezvous.cc (new file, 300 lines)
@@ -0,0 +1,300 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/local_rendezvous.h"
+
+#include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/notification.h"
+#include "tensorflow/core/lib/gtl/manual_constructor.h"
+#include "tensorflow/core/lib/strings/numbers.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+
+// Represents a blocked Send() or Recv() call in the rendezvous.
+struct LocalRendezvous::Item {
+  enum Type { kSend = 0, kRecv = 1 };
+
+  Item(Rendezvous::Args send_args, const Tensor& value, bool is_dead)
+      : Item(send_args, kSend) {
+    send_state.value.Init(value);
+    send_state.is_dead = is_dead;
+  }
+
+  Item(Rendezvous::Args recv_args, Rendezvous::DoneCallback waiter,
+       CancellationToken cancellation_token)
+      : Item(recv_args, kRecv) {
+    recv_state.waiter.Init(std::move(waiter));
+    recv_state.cancellation_token = cancellation_token;
+  }
+
+  ~Item() {
+    if (args.device_context) {
+      args.device_context->Unref();
+    }
+    if (type == kSend) {
+      send_state.value.Destroy();
+    } else {
+      recv_state.waiter.Destroy();
+    }
+  }
+
+  const Rendezvous::Args args;
+  const Type type;
+
+  // Link to next item in an ItemQueue.
+  Item* next = nullptr;
+
+  // The validity of `send_state` or `recv_state` is determined by `type ==
+  // kSend` or `type == kRecv` respectively.
+  union {
+    struct {
+      ManualConstructor<Tensor> value;
+      bool is_dead;
+    } send_state;
+    struct {
+      ManualConstructor<Rendezvous::DoneCallback> waiter;
+      CancellationToken cancellation_token;
+    } recv_state;
+  };
+
+ private:
+  Item(Rendezvous::Args args, Type type) : args(args), type(type) {
+    if (args.device_context) {
+      args.device_context->Ref();
+    }
+  }
+};
+
+void LocalRendezvous::ItemQueue::push_back(Item* item) {
+  if (TF_PREDICT_TRUE(head == nullptr)) {
+    // The queue is empty.
+    head = item;
+    tail = item;
+  } else {
+    DCHECK_EQ(tail->type, item->type);
+    tail->next = item;
+    tail = item;
+  }
+}
+
+LocalRendezvous::~LocalRendezvous() {
+  if (!table_.empty()) {
+    StartAbort(errors::Cancelled("LocalRendezvous deleted"));
+  }
+}
+
+namespace {
+uint64 KeyHash(const StringPiece& k) { return Hash64(k.data(), k.size()); }
+}  // namespace
+
+Status LocalRendezvous::Send(const Rendezvous::ParsedKey& key,
+                             const Rendezvous::Args& send_args,
+                             const Tensor& val, const bool is_dead) {
+  uint64 key_hash = KeyHash(key.FullKey());
+  DVLOG(2) << "Send " << this << " " << key_hash << " " << key.FullKey();
+
+  mu_.lock();
+  if (!status_.ok()) {
+    // Rendezvous has been aborted.
+    Status s = status_;
+    mu_.unlock();
+    return s;
+  }
+
+  ItemQueue* queue = &table_[key_hash];
+  if (queue->head == nullptr || queue->head->type == Item::kSend) {
+    // There is no waiter for this message. Append the message
+    // into the queue. The waiter will pick it up when arrives.
+    // Only send-related fields need to be filled.
+    // TODO(b/143786186): Investigate moving the allocation of `Item` outside
+    // the lock.
+    DVLOG(2) << "Enqueue Send Item (key:" << key.FullKey() << "). ";
+    queue->push_back(new Item(send_args, val, is_dead));
+    mu_.unlock();
+    return Status::OK();
+  }
+
+  DVLOG(2) << "Consume Recv Item (key:" << key.FullKey() << "). ";
+  // There is an earliest waiter to consume this message.
+  Item* item = queue->head;
+
+  // Delete the queue when the last element has been consumed.
+  if (item->next == nullptr) {
+    DVLOG(2) << "Clean up Send/Recv queue (key:" << key.FullKey() << "). ";
+    table_.erase(key_hash);
+  } else {
+    queue->head = item->next;
+  }
+  mu_.unlock();
+
+  // Notify the waiter by invoking its done closure, outside the
+  // lock.
+  DCHECK_EQ(item->type, Item::kRecv);
+  (*item->recv_state.waiter)(Status::OK(), send_args, item->args, val, is_dead);
+  delete item;
+  return Status::OK();
+}
+
+void LocalRendezvous::RecvAsync(const Rendezvous::ParsedKey& key,
+                                const Rendezvous::Args& recv_args,
+                                Rendezvous::DoneCallback done) {
+  uint64 key_hash = KeyHash(key.FullKey());
+  DVLOG(2) << "Recv " << this << " " << key_hash << " " << key.FullKey();
+
+  mu_.lock();
+  if (!status_.ok()) {
+    // Rendezvous has been aborted.
+    Status s = status_;
+    mu_.unlock();
+    done(s, Rendezvous::Args(), recv_args, Tensor(), false);
+    return;
+  }
+
+  ItemQueue* queue = &table_[key_hash];
+  if (queue->head == nullptr || queue->head->type == Item::kRecv) {
+    // There is no message to pick up.
+    // Only recv-related fields need to be filled.
+    CancellationManager* cm = recv_args.cancellation_manager;
+    CancellationToken token = CancellationManager::kInvalidToken;
+    bool already_cancelled = false;
+    if (cm != nullptr) {
+      token = cm->get_cancellation_token();
+      already_cancelled = !cm->RegisterCallback(token, [this, token, key_hash] {
+        Item* item = nullptr;
+        {
+          mutex_lock l(mu_);
+          ItemQueue* queue = &table_[key_hash];
+          // Find an item in the queue with a cancellation token that matches
+          // `token`, and remove it.
+          if (queue->head != nullptr && queue->head->type == Item::kRecv) {
+            for (Item *prev = nullptr, *curr = queue->head; curr != nullptr;
+                 prev = curr, curr = curr->next) {
+              if (curr->recv_state.cancellation_token == token) {
+                item = curr;
+                if (queue->head->next == nullptr) {
+                  // We have a single-element queue, so we can erase it from
+                  // the table.
+                  table_.erase(key_hash);
+                } else {
+                  // Remove the current item from the queue.
+                  if (curr == queue->head) {
+                    DCHECK_EQ(prev, nullptr);
+                    queue->head = curr->next;
+                  } else {
+                    DCHECK_NE(prev, nullptr);
+                    prev->next = curr->next;
+                  }
+                  if (queue->tail == curr) {
+                    queue->tail = prev;
+                  }
+                }
+                break;
+              }
+            }
+          }
+        }
+
+        if (item != nullptr) {
+          (*item->recv_state.waiter)(
+              StatusGroup::MakeDerived(
+                  errors::Cancelled("RecvAsync is cancelled.")),
+              Rendezvous::Args(), item->args, Tensor(), /*is_dead=*/false);
+          delete item;
+        }
+      });
+    }
+    if (already_cancelled) {
+      mu_.unlock();
+      done(StatusGroup::MakeDerived(
+               errors::Cancelled("RecvAsync is cancelled.")),
+           Rendezvous::Args(), recv_args, Tensor(), /*is_dead=*/false);
+      return;
+    }
+
+    DVLOG(2) << "Enqueue Recv Item (key:" << key.FullKey() << "). ";
+
+    // TODO(b/143786186): Investigate moving the allocation of `Item` outside
+    // the lock.
+    if (cm != nullptr) {
+      // NOTE(mrry): We must wrap `done` with code that deregisters the
+      // cancellation callback before calling the `done` callback, because the
+      // cancellation manager may no longer be live after `done` is called.
+      queue->push_back(new Item(
+          recv_args,
+          [cm, token, done = std::move(done)](
+              const Status& s, const Rendezvous::Args& send_args,
+              const Rendezvous::Args& recv_args, const Tensor& v, bool dead) {
+            cm->TryDeregisterCallback(token);
+            done(s, send_args, recv_args, v, dead);
+          },
+          token));
+    } else {
+      queue->push_back(new Item(recv_args, std::move(done), token));
+    }
+
+    mu_.unlock();
+    return;
+  }
+
+  DVLOG(2) << "Consume Send Item (key:" << key.FullKey() << "). ";
+  // A message has already arrived and is queued in the table under
+  // this key. Consumes the message and invokes the done closure.
+  Item* item = queue->head;
+
+  // Delete the queue when the last element has been consumed.
+  if (item->next == nullptr) {
+    DVLOG(2) << "Clean up Send/Recv queue (key:" << key.FullKey() << "). ";
+    table_.erase(key_hash);
+  } else {
+    queue->head = item->next;
+  }
+  mu_.unlock();
+
+  // Invoke done() without holding the table lock.
+  DCHECK_EQ(item->type, Item::kSend);
+  done(Status::OK(), item->args, recv_args, *item->send_state.value,
+       item->send_state.is_dead);
+  delete item;
+}
+
+void LocalRendezvous::StartAbort(const Status& status) {
+  CHECK(!status.ok());
+  Table table;
+  {
+    mutex_lock l(mu_);
+    status_.Update(status);
+    table_.swap(table);
+  }
+  for (auto& p : table) {
+    Item* item = p.second.head;
+    while (item != nullptr) {
+      if (item->type == Item::kRecv) {
+        (*item->recv_state.waiter)(status, Rendezvous::Args(),
+                                   Rendezvous::Args(), Tensor(), false);
+      }
+      Item* to_delete = item;
+      item = item->next;
+      delete to_delete;
+    }
+  }
+}
+
+}  // namespace tensorflow
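Item stores either send state or recv state in a union, with gtl::ManualConstructor managing the lifetime of whichever member is active. Below is a simplified, self-contained sketch of the same tagged-union pattern in plain C++: std::string stands in for Tensor, std::function for the recv waiter, placement new and explicit destructor calls for ManualConstructor, and `SketchItem` is a name invented for illustration.

#include <functional>
#include <new>
#include <string>
#include <utility>

struct SketchItem {
  enum Type { kSend, kRecv };
  using Payload = std::string;              // stand-in for ManualConstructor<Tensor>
  using Waiter = std::function<void(int)>;  // stand-in for the recv waiter

  const Type type;
  union {
    Payload send_payload;  // alive only when type == kSend
    Waiter recv_waiter;    // alive only when type == kRecv
  };

  explicit SketchItem(Payload p) : type(kSend) {
    new (&send_payload) Payload(std::move(p));  // construct the active member
  }
  explicit SketchItem(Waiter w) : type(kRecv) {
    new (&recv_waiter) Waiter(std::move(w));
  }

  // Mirror Item::~Item(): destroy only the member selected by `type`.
  ~SketchItem() {
    if (type == kSend) {
      send_payload.~Payload();
    } else {
      recv_waiter.~Waiter();
    }
  }

  SketchItem(const SketchItem&) = delete;
  SketchItem& operator=(const SketchItem&) = delete;
};

The payoff is one allocation per queued item that holds either state, with no space or construction cost for the inactive branch.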
tensorflow/core/framework/local_rendezvous.h (new file, 75 lines)
@@ -0,0 +1,75 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_FRAMEWORK_LOCAL_RENDEZVOUS_H_
+#define TENSORFLOW_CORE_FRAMEWORK_LOCAL_RENDEZVOUS_H_
+
+#include "tensorflow/core/framework/rendezvous.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/gtl/flatmap.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+
+// Implements the basic logic of matching Send and Recv operations. See
+// RendezvousInterface for more details.
+//
+// NOTE: Most users will use a class that wraps LocalRendezvous, such as
+// IntraProcessRendezvous or RemoteRendezvous. This class does not implement
+// RendezvousInterface because virtual dispatch to LocalRendezvous methods
+// is not expected to be needed.
+class LocalRendezvous {
+ public:
+  LocalRendezvous() = default;
+  ~LocalRendezvous();
+
+  Status Send(const Rendezvous::ParsedKey& key,
+              const Rendezvous::Args& send_args, const Tensor& val,
+              const bool is_dead);
+  void RecvAsync(const Rendezvous::ParsedKey& key,
+                 const Rendezvous::Args& recv_args,
+                 Rendezvous::DoneCallback done);
+  void StartAbort(const Status& status);
+
+ private:
+  struct Item;
+
+  // By invariant, the item queue under each key is of the form
+  //   [item.type == kSend]* meaning each item is a sent message.
+  // or
+  //   [item.type == kRecv]* meaning each item is a waiter.
+  struct ItemQueue {
+    void push_back(Item* item);
+
+    Item* head = nullptr;
+    Item* tail = nullptr;
+  };
+
+  typedef gtl::FlatMap<uint64, ItemQueue> Table;
+
+  // TODO(zhifengc): shard table_.
+  mutex mu_;
+  Table table_ GUARDED_BY(mu_);
+  Status status_ GUARDED_BY(mu_);
+
+  TF_DISALLOW_COPY_AND_ASSIGN(LocalRendezvous);
+};
+
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_CORE_FRAMEWORK_LOCAL_RENDEZVOUS_H_
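For reference, a hedged usage sketch of the new class (not part of this commit). It assumes the Rendezvous::CreateKey/ParseKey helpers and FrameAndIter behave as declared in tensorflow/core/framework/rendezvous.h at this revision; `SendThenRecvExample` is a hypothetical function name.

#include "tensorflow/core/framework/local_rendezvous.h"
#include "tensorflow/core/framework/rendezvous.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/errors.h"

namespace tensorflow {

Status SendThenRecvExample() {
  LocalRendezvous rendez;  // plain value member: no Ref()/Unref() dance

  const string device = "/job:localhost/replica:0/task:0/device:CPU:0";
  Rendezvous::ParsedKey key;
  TF_RETURN_IF_ERROR(Rendezvous::ParseKey(
      Rendezvous::CreateKey(device, 1 /* src_incarnation */, device,
                            "tensor_a", FrameAndIter(0, 0)),
      &key));

  // Producer: no waiter is queued yet, so the value is buffered in the table.
  TF_RETURN_IF_ERROR(rendez.Send(key, Rendezvous::Args(), Tensor(), false));

  // Consumer: a send is already queued under this key, so the callback runs
  // inline and capturing `recv_status` by reference is safe in this sketch.
  Status recv_status;
  rendez.RecvAsync(key, Rendezvous::Args(),
                   [&recv_status](const Status& s, const Rendezvous::Args&,
                                  const Rendezvous::Args&, const Tensor&,
                                  bool) { recv_status = s; });
  return recv_status;
}

}  // namespace tensorflow

The final diff below replaces the old monolithic LocalRendezvousImpl with a thin LocalRendezvousWrapper that implements the Rendezvous interface by delegating to this class.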
@@ -20,6 +20,7 @@ limitations under the License.
 #include <utility>
 #include <vector>
 
+#include "tensorflow/core/framework/local_rendezvous.h"
 #include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/lib/core/notification.h"
 #include "tensorflow/core/lib/gtl/flatmap.h"
@@ -148,301 +149,29 @@ Status RendezvousInterface::Recv(const ParsedKey& key, const Args& args,
 }
 
 namespace {
-class LocalRendezvousImpl : public Rendezvous {
+class LocalRendezvousWrapper : public Rendezvous {
  public:
-  explicit LocalRendezvousImpl() {}
+  LocalRendezvousWrapper() = default;
 
   Status Send(const ParsedKey& key, const Args& send_args, const Tensor& val,
               const bool is_dead) override {
-    uint64 key_hash = KeyHash(key.FullKey());
-    DVLOG(2) << "Send " << this << " " << key_hash << " " << key.FullKey();
-
-    mu_.lock();
-    if (!status_.ok()) {
-      // Rendezvous has been aborted.
-      Status s = status_;
-      mu_.unlock();
-      return s;
-    }
-
-    ItemQueue* queue = &table_[key_hash];
-    if (queue->head == nullptr || queue->head->type == Item::kSend) {
-      // There is no waiter for this message. Append the message
-      // into the queue. The waiter will pick it up when arrives.
-      // Only send-related fields need to be filled.
-      // TODO(b/143786186): Investigate moving the allocation of `Item` outside
-      // the lock.
-      DVLOG(2) << "Enqueue Send Item (key:" << key.FullKey() << "). ";
-      queue->push_back(new Item(send_args, val, is_dead));
-      mu_.unlock();
-      return Status::OK();
-    }
-
-    DVLOG(2) << "Consume Recv Item (key:" << key.FullKey() << "). ";
-    // There is an earliest waiter to consume this message.
-    Item* item = queue->head;
-
-    // Delete the queue when the last element has been consumed.
-    if (item->next == nullptr) {
-      DVLOG(2) << "Clean up Send/Recv queue (key:" << key.FullKey() << "). ";
-      table_.erase(key_hash);
-    } else {
-      queue->head = item->next;
-    }
-    mu_.unlock();
-
-    // Notify the waiter by invoking its done closure, outside the
-    // lock.
-    DCHECK_EQ(item->type, Item::kRecv);
-    (*item->recv_state.waiter)(Status::OK(), send_args, item->args, val,
-                               is_dead);
-    delete item;
-    return Status::OK();
+    return impl_.Send(key, send_args, val, is_dead);
   }
 
   void RecvAsync(const ParsedKey& key, const Args& recv_args,
                  DoneCallback done) override {
-    uint64 key_hash = KeyHash(key.FullKey());
-    DVLOG(2) << "Recv " << this << " " << key_hash << " " << key.FullKey();
-
-    mu_.lock();
-    if (!status_.ok()) {
-      // Rendezvous has been aborted.
-      Status s = status_;
-      mu_.unlock();
-      done(s, Args(), recv_args, Tensor(), false);
-      return;
-    }
-
-    ItemQueue* queue = &table_[key_hash];
-    if (queue->head == nullptr || queue->head->type == Item::kRecv) {
-      // There is no message to pick up.
-      // Only recv-related fields need to be filled.
-      CancellationManager* cm = recv_args.cancellation_manager;
-      CancellationToken token = CancellationManager::kInvalidToken;
-      bool already_cancelled = false;
-      if (cm != nullptr) {
-        token = cm->get_cancellation_token();
-        already_cancelled = !cm->RegisterCallback(token, [this, token,
-                                                          key_hash] {
-          Item* item = nullptr;
-          {
-            mutex_lock l(mu_);
-            ItemQueue* queue = &table_[key_hash];
-            // Find an item in the queue with a cancellation token that matches
-            // `token`, and remove it.
-            if (queue->head != nullptr && queue->head->type == Item::kRecv) {
-              for (Item *prev = nullptr, *curr = queue->head; curr != nullptr;
-                   prev = curr, curr = curr->next) {
-                if (curr->recv_state.cancellation_token == token) {
-                  item = curr;
-                  if (queue->head->next == nullptr) {
-                    // We have a single-element queue, so we can erase it from
-                    // the table.
-                    table_.erase(key_hash);
-                  } else {
-                    // Remove the current item from the queue.
-                    if (curr == queue->head) {
-                      DCHECK_EQ(prev, nullptr);
-                      queue->head = curr->next;
-                    } else {
-                      DCHECK_NE(prev, nullptr);
-                      prev->next = curr->next;
-                    }
-                    if (queue->tail == curr) {
-                      queue->tail = prev;
-                    }
-                  }
-                  break;
-                }
-              }
-            }
-          }
-
-          if (item != nullptr) {
-            (*item->recv_state.waiter)(
-                StatusGroup::MakeDerived(
-                    errors::Cancelled("RecvAsync is cancelled.")),
-                Args(), item->args, Tensor(), /*is_dead=*/false);
-            delete item;
-          }
-        });
-      }
-      if (already_cancelled) {
-        mu_.unlock();
-        done(StatusGroup::MakeDerived(
-                 errors::Cancelled("RecvAsync is cancelled.")),
-             Args(), recv_args, Tensor(), /*is_dead=*/false);
-        return;
-      }
-
-      DVLOG(2) << "Enqueue Recv Item (key:" << key.FullKey() << "). ";
-
-      // TODO(b/143786186): Investigate moving the allocation of `Item` outside
-      // the lock.
-      if (cm != nullptr) {
-        // NOTE(mrry): We must wrap `done` with code that deregisters the
-        // cancellation callback before calling the `done` callback, because the
-        // cancellation manager may no longer be live after `done` is called.
-        queue->push_back(new Item(
-            recv_args,
-            [cm, token, done = std::move(done)](
-                const Status& s, const Args& send_args, const Args& recv_args,
-                const Tensor& v, bool dead) {
-              cm->TryDeregisterCallback(token);
-              done(s, send_args, recv_args, v, dead);
-            },
-            token));
-      } else {
-        queue->push_back(new Item(recv_args, std::move(done), token));
-      }
-
-      mu_.unlock();
-      return;
-    }
-
-    DVLOG(2) << "Consume Send Item (key:" << key.FullKey() << "). ";
-    // A message has already arrived and is queued in the table under
-    // this key. Consumes the message and invokes the done closure.
-    Item* item = queue->head;
-
-    // Delete the queue when the last element has been consumed.
-    if (item->next == nullptr) {
-      DVLOG(2) << "Clean up Send/Recv queue (key:" << key.FullKey() << "). ";
-      table_.erase(key_hash);
-    } else {
-      queue->head = item->next;
-    }
-    mu_.unlock();
-
-    // Invoke done() without holding the table lock.
-    DCHECK_EQ(item->type, Item::kSend);
-    done(Status::OK(), item->args, recv_args, *item->send_state.value,
-         item->send_state.is_dead);
-    delete item;
+    impl_.RecvAsync(key, recv_args, std::move(done));
   }
 
-  void StartAbort(const Status& status) override {
-    CHECK(!status.ok());
-    Table table;
-    {
-      mutex_lock l(mu_);
-      status_.Update(status);
-      table_.swap(table);
-    }
-    for (auto& p : table) {
-      Item* item = p.second.head;
-      while (item != nullptr) {
-        if (item->type == Item::kRecv) {
-          (*item->recv_state.waiter)(status, Args(), Args(), Tensor(), false);
-        }
-        Item* to_delete = item;
-        item = item->next;
-        delete to_delete;
-      }
-    }
-  }
+  void StartAbort(const Status& status) override { impl_.StartAbort(status); }
 
  private:
-  typedef LocalRendezvousImpl ME;
+  LocalRendezvous impl_;
 
-  // Represents a blocked Send() or Recv() call in the rendezvous.
-  struct Item {
-    enum Type { kSend = 0, kRecv = 1 };
-
-    Item(Args send_args, const Tensor& value, bool is_dead)
-        : Item(send_args, kSend) {
-      send_state.value.Init(value);
-      send_state.is_dead = is_dead;
-    }
-
-    Item(Args recv_args, DoneCallback waiter,
-         CancellationToken cancellation_token)
-        : Item(recv_args, kRecv) {
-      recv_state.waiter.Init(std::move(waiter));
-      recv_state.cancellation_token = cancellation_token;
-    }
-
-    ~Item() {
-      if (args.device_context) {
-        args.device_context->Unref();
-      }
-      if (type == kSend) {
-        send_state.value.Destroy();
-      } else {
-        recv_state.waiter.Destroy();
-      }
-    }
-
-    const Args args;
-    const Type type;
-
-    // Link to next item in an ItemQueue.
-    Item* next = nullptr;
-
-    // The validity of `send_state` or `recv_state` is determined by `type ==
-    // kSend` or `type == kRecv` respectively.
-    union {
-      struct {
-        ManualConstructor<Tensor> value;
-        bool is_dead;
-      } send_state;
-      struct {
-        ManualConstructor<DoneCallback> waiter;
-        CancellationToken cancellation_token;
-      } recv_state;
-    };
-
-   private:
-    Item(Args args, Type type) : args(args), type(type) {
-      if (args.device_context) {
-        args.device_context->Ref();
-      }
-    }
-  };
-
-  // We key the hash table by KeyHash of the Rendezvous::CreateKey string
-  static uint64 KeyHash(const StringPiece& k) {
-    return Hash64(k.data(), k.size());
-  }
-
-  // By invariant, the item queue under each key is of the form
-  //   [item.type == kSend]* meaning each item is a sent message.
-  // or
-  //   [item.type == kRecv]* meaning each item is a waiter.
-  struct ItemQueue {
-    void push_back(Item* item) {
-      if (TF_PREDICT_TRUE(head == nullptr)) {
-        // The queue is empty.
-        head = item;
-        tail = item;
-      } else {
-        DCHECK_EQ(tail->type, item->type);
-        tail->next = item;
-        tail = item;
-      }
-    }
-
-    Item* head = nullptr;
-    Item* tail = nullptr;
-  };
-  typedef gtl::FlatMap<uint64, ItemQueue> Table;
-
-  // TODO(zhifengc): shard table_.
-  mutex mu_;
-  Table table_ GUARDED_BY(mu_);
-  Status status_ GUARDED_BY(mu_);
-
-  ~LocalRendezvousImpl() override {
-    if (!table_.empty()) {
-      StartAbort(errors::Cancelled("LocalRendezvousImpl deleted"));
-    }
-  }
-
-  TF_DISALLOW_COPY_AND_ASSIGN(LocalRendezvousImpl);
+  TF_DISALLOW_COPY_AND_ASSIGN(LocalRendezvousWrapper);
 };
 }  // namespace
 
-Rendezvous* NewLocalRendezvous() { return new LocalRendezvousImpl(); }
+Rendezvous* NewLocalRendezvous() { return new LocalRendezvousWrapper; }
 
 }  // end namespace tensorflow
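The NOTE(mrry) comment carried over above encodes a subtle ordering requirement: the recv waiter must drop its cancellation registration before running `done`, because the CancellationManager may not outlive the `done` call. Here is a minimal sketch of that pattern in isolation; `WrapWithDeregistration` and `SimpleDone` are hypothetical names for illustration, and only CancellationManager calls that appear in this diff are used.

#include <functional>
#include <utility>

#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/lib/core/status.h"

namespace tensorflow {

// Hypothetical single-argument done callback, for illustration only.
using SimpleDone = std::function<void(const Status&)>;

inline SimpleDone WrapWithDeregistration(CancellationManager* cm,
                                         CancellationToken token,
                                         SimpleDone done) {
  return [cm, token, done = std::move(done)](const Status& s) {
    // Deregister first: once `done` returns, `cm` (often owned by per-step
    // resources) may already have been destroyed.
    cm->TryDeregisterCallback(token);
    done(s);
  };
}

}  // namespace tensorflow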