Fix cancellation race condition in BaseRendezvousMgr::RegisterCall
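
RegisterCall used to check the abort status under init_mu_ and then insert
the call into active_ under a separate mutex, active_mu_. A concurrent
StartAbort() could interleave between the two steps: it sets status_ and
aborts every call already in active_, after which RegisterCall still inserts
the new call, leaving it registered but never aborted. This change merges
init_mu_ and active_mu_ into a single mutex mu_, so the status check, the
cancellation-callback registration, and the insertion into active_ are
atomic with respect to StartAbort(). RpcRemoteRendezvous::RecvFromRemoteAsync
also now deregisters the call before releasing it to the call freelist on
its early-return path.

The racy shape, as a minimal standalone sketch (hypothetical names, not the
TensorFlow code itself):

  #include <mutex>
  #include <unordered_set>

  struct Call { void Abort() {} };

  std::mutex status_mu;    // guards `aborted`
  std::mutex active_mu;    // guards `active`
  bool aborted = false;
  std::unordered_set<Call*> active;

  void RegisterCall(Call* call) {
    {
      std::lock_guard<std::mutex> l(status_mu);
      if (aborted) { call->Abort(); return; }
    }
    // StartAbort() may run here: it aborts everything already in `active`,
    // but `call` has not been inserted yet.
    std::lock_guard<std::mutex> l(active_mu);
    active.insert(call);   // registered, but never aborted
  }

  void StartAbort() {
    std::lock_guard<std::mutex> l1(status_mu);
    std::lock_guard<std::mutex> l2(active_mu);
    aborted = true;
    for (Call* c : active) c->Abort();
    active.clear();
  }

With a single mutex, the status check and the insertion form one critical
section: a registering call either observes the abort status or is present
in active_ by the time StartAbort() walks the map.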

PiperOrigin-RevId: 317363743
Change-Id: Ide89dd360a9885b5e8f67b12f362cbce8cb85d80
commit bd98ba765a (parent 0869ff0af5)
Haoyu Zhang, 2020-06-19 13:00:32 -07:00; committed by TensorFlower Gardener
3 changed files with 37 additions and 47 deletions

tensorflow/core/distributed_runtime/base_rendezvous_mgr.cc

@@ -139,7 +139,7 @@ Status BaseRemoteRendezvous::Initialize(WorkerSession* session) {
   CHECK_NE(session, nullptr) << "session must not be null!";
   std::vector<DeferredCall> deferred_calls;
   {
-    mutex_lock l(init_mu_);
+    mutex_lock l(mu_);
     if (session_ != nullptr) {
       if (session_->worker_name() == session->worker_name()) {
         VLOG(1) << "Skipping rendezvous re-initialization.";
@@ -161,12 +161,12 @@ Status BaseRemoteRendezvous::Initialize(WorkerSession* session) {
 }
 
 WorkerSession* BaseRemoteRendezvous::session() {
-  tf_shared_lock l(init_mu_);
+  tf_shared_lock l(mu_);
   return session_;
 }
 
 bool BaseRemoteRendezvous::is_initialized() {
-  tf_shared_lock l(init_mu_);
+  tf_shared_lock l(mu_);
   return is_initialized_locked();
 }
@@ -176,7 +176,7 @@ Status BaseRemoteRendezvous::Send(const Rendezvous::ParsedKey& parsed,
   VLOG(1) << "BaseRemoteRendezvous Send " << this << " " << parsed.FullKey();
   WorkerSession* sess = nullptr;
   {
-    tf_shared_lock l(init_mu_);
+    tf_shared_lock l(mu_);
     if (!status_.ok()) return status_;
     DCHECK(is_initialized_locked());
     sess = session_;
@@ -198,7 +198,7 @@ Status BaseRemoteRendezvous::ValidateDevices(const ParsedKey& parsed,
   // (e.g. calling session())
   WorkerSession* sess = nullptr;
   {
-    tf_shared_lock l(init_mu_);
+    tf_shared_lock l(mu_);
     if (!status_.ok()) return status_;
     if (!is_initialized_locked()) {
       return errors::Internal("ValidateDevices called before initialization.");
@@ -345,7 +345,7 @@ void BaseRemoteRendezvous::RecvLocalAsync(const ParsedKey& parsed,
   // Test whether the rendezvous is initialized using a shared lock, to avoid
   // the need for exclusive access in the common case.
   if (TF_PREDICT_FALSE(!is_initialized())) {
-    mutex_lock l(init_mu_);
+    mutex_lock l(mu_);
     if (!is_initialized_locked()) {
       // RecvLocalAsync can be called (due to an incoming RecvTensor RPC from a
       // remote worker) before the RunStep (or PartialRunStep) RPC from the
@@ -386,8 +386,7 @@ void BaseRemoteRendezvous::StartAbort(const Status& s) {
   local_->StartAbort(derived_status);
   {
     // Aborts all active RecvTensor calls.
-    mutex_lock l(init_mu_);
-    mutex_lock l2(active_mu_);
+    mutex_lock l(mu_);
     if (status_.ok()) {
       status_ = derived_status;
       for (auto& entry : active_) {
@@ -402,42 +401,36 @@ void BaseRemoteRendezvous::StartAbort(const Status& s) {
 void BaseRemoteRendezvous::RegisterCall(BaseRecvTensorCall* call,
                                         const Rendezvous::Args& args) {
   CancellationManager* cm = args.cancellation_manager;
-  Status captured_status;
-  {
-    tf_shared_lock l(init_mu_);
-    if (!status_.ok()) {
-      captured_status = status_;
-    }
-  }
-  if (!captured_status.ok()) {
-    call->StartAbort(captured_status);
-    return;
-  }
-
   bool already_cancelled = false;
   InactiveCallback callback = [] {};
-  if (cm != nullptr) {
-    auto token = cm->get_cancellation_token();
-    already_cancelled = !cm->RegisterCallback(token, [this, call] {
-      {
-        tf_shared_lock l(active_mu_);
-        if (active_.find(call) == active_.end()) return;
-      }
-      call->StartAbort(
-          errors::Cancelled("RecvFromRemoteAsync is cancelled."));
-    });
-    callback = [cm, token] { cm->TryDeregisterCallback(token); };
-  }
-  if (already_cancelled) {
-    call->StartAbort(errors::Cancelled("RecvFromRemoteAsync is cancelled."));
-  } else {
-    mutex_lock l(active_mu_);
-    CHECK(active_.emplace(call, callback).second);
+  {
+    mutex_lock l(mu_);
+    if (!status_.ok()) {
+      call->StartAbort(status_);
+      return;
+    }
+    if (cm != nullptr) {
+      auto token = cm->get_cancellation_token();
+      already_cancelled = !cm->RegisterCallback(token, [this, call] {
+        {
+          mutex_lock l(mu_);
+          if (active_.find(call) == active_.end()) return;
+          call->StartAbort(
+              errors::Cancelled("RecvFromRemoteAsync is cancelled."));
+        }
+      });
+      callback = [cm, token] { cm->TryDeregisterCallback(token); };
+    }
+    if (already_cancelled) {
+      call->StartAbort(errors::Cancelled("RecvFromRemoteAsync is cancelled."));
+    } else {
+      CHECK(active_.emplace(call, callback).second);
+    }
   }
 }
 
 void BaseRemoteRendezvous::DeregisterCall(BaseRecvTensorCall* call) {
-  mutex_lock l(active_mu_);
+  mutex_lock l(mu_);
   auto it = active_.find(call);
   if (it != active_.end()) {
     // Deregister the cancellation callback, if one was registered.

tensorflow/core/distributed_runtime/base_rendezvous_mgr.h

@@ -174,14 +174,12 @@ class BaseRemoteRendezvous : public RemoteRendezvous {
  private:
   Rendezvous* local_;  // Owns a Ref on this object.
 
   // Guards mutable state that is read-mostly after this rendezvous is
   // initialized.
-  mutable mutex init_mu_;
-
+  mutable mutex mu_;
   // Status given by StartAbort() if any.
-  Status status_ TF_GUARDED_BY(init_mu_);
-
-  WorkerSession* session_ TF_GUARDED_BY(init_mu_);  // Not owned.
+  Status status_ TF_GUARDED_BY(mu_);
+  WorkerSession* session_ TF_GUARDED_BY(mu_);  // Not owned.
 
   // Data structures to handle calls when partially initialized.
   struct DeferredCall {
@@ -190,16 +188,14 @@ class BaseRemoteRendezvous : public RemoteRendezvous {
     DeferredCall(const ParsedKey& parsed, DoneCallback done);
   };
-  std::vector<DeferredCall> deferred_calls_ TF_GUARDED_BY(init_mu_);
+  std::vector<DeferredCall> deferred_calls_ TF_GUARDED_BY(mu_);
 
   typedef std::function<void()> InactiveCallback;
 
   // Active outstanding RecvTensor calls.
-  mutex active_mu_;
   std::unordered_map<BaseRecvTensorCall*, InactiveCallback> active_
-      TF_GUARDED_BY(active_mu_);
+      TF_GUARDED_BY(mu_);
 
-  bool is_initialized_locked() TF_SHARED_LOCKS_REQUIRED(init_mu_) {
+  bool is_initialized_locked() TF_SHARED_LOCKS_REQUIRED(mu_) {
     return session_ != nullptr;
   }

tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.cc

@@ -282,6 +282,7 @@ void RpcRemoteRendezvous::RecvFromRemoteAsync(
     // callback.
     call->ReleaseWorker(sess->worker_cache());
     call->done()(call->status(), Args(), Args(), Tensor(), false);
+    DeregisterCall(call);
     get_call_freelist()->Release(call);
     return;
   }