Fix cancellation race condition in BaseRemoteRendezvous::RegisterCall
PiperOrigin-RevId: 317363743
Change-Id: Ide89dd360a9885b5e8f67b12f362cbce8cb85d80
commit bd98ba765a
parent 0869ff0af5
tensorflow/core/distributed_runtime
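The race being fixed, as a minimal, self-contained sketch (hypothetical code, not TensorFlow; the names Call, active_mu, active, and cancel_cb are invented for illustration): the cancellation callback consults the active-call set, but the call is only inserted into that set afterwards, in a separate critical section, so a cancellation that fires in the gap is silently dropped.

// Hypothetical, self-contained sketch of the race (not TensorFlow code).
#include <functional>
#include <iostream>
#include <mutex>
#include <unordered_set>

struct Call {
  bool aborted = false;
  void StartAbort() { aborted = true; }
};

std::mutex active_mu;               // plays the role of the old active_mu_
std::unordered_set<Call*> active;   // plays the role of active_
std::function<void()> cancel_cb;    // stands in for the CancellationManager

void RegisterCall(Call* call) {
  // 1. The cancellation callback is registered first; it only aborts the
  //    call if the call is already in the active set.
  cancel_cb = [call] {
    std::lock_guard<std::mutex> l(active_mu);
    if (active.find(call) == active.end()) return;  // call not visible yet
    call->StartAbort();
  };
  // 2. The race window: a cancellation that fires here finds nothing to
  //    abort. (Invoked directly to keep the sketch deterministic.)
  cancel_cb();
  // 3. Only now is the call inserted, under a separate critical section.
  std::lock_guard<std::mutex> l(active_mu);
  active.insert(call);
}

int main() {
  Call call;
  RegisterCall(&call);
  // The cancellation was requested, but the call was never aborted.
  std::cout << "aborted = " << std::boolalpha << call.aborted << "\n";
}

In the real code the window sits between cm->RegisterCallback(...) and the mutex_lock l(active_mu_) guarding active_.emplace(...); a concurrent CancellationManager fires the callback instead of the direct cancel_cb() call used here.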
@@ -139,7 +139,7 @@ Status BaseRemoteRendezvous::Initialize(WorkerSession* session) {
   CHECK_NE(session, nullptr) << "session must not be null!";
   std::vector<DeferredCall> deferred_calls;
   {
-    mutex_lock l(init_mu_);
+    mutex_lock l(mu_);
     if (session_ != nullptr) {
       if (session_->worker_name() == session->worker_name()) {
         VLOG(1) << "Skipping rendezvous re-initialization.";
@@ -161,12 +161,12 @@ Status BaseRemoteRendezvous::Initialize(WorkerSession* session) {
 }
 
 WorkerSession* BaseRemoteRendezvous::session() {
-  tf_shared_lock l(init_mu_);
+  tf_shared_lock l(mu_);
   return session_;
 }
 
 bool BaseRemoteRendezvous::is_initialized() {
-  tf_shared_lock l(init_mu_);
+  tf_shared_lock l(mu_);
   return is_initialized_locked();
 }
 
@@ -176,7 +176,7 @@ Status BaseRemoteRendezvous::Send(const Rendezvous::ParsedKey& parsed,
   VLOG(1) << "BaseRemoteRendezvous Send " << this << " " << parsed.FullKey();
   WorkerSession* sess = nullptr;
   {
-    tf_shared_lock l(init_mu_);
+    tf_shared_lock l(mu_);
     if (!status_.ok()) return status_;
     DCHECK(is_initialized_locked());
     sess = session_;
@@ -198,7 +198,7 @@ Status BaseRemoteRendezvous::ValidateDevices(const ParsedKey& parsed,
   // (e.g. calling session())
   WorkerSession* sess = nullptr;
   {
-    tf_shared_lock l(init_mu_);
+    tf_shared_lock l(mu_);
     if (!status_.ok()) return status_;
     if (!is_initialized_locked()) {
       return errors::Internal("ValidateDevices called before initialization.");
@@ -345,7 +345,7 @@ void BaseRemoteRendezvous::RecvLocalAsync(const ParsedKey& parsed,
   // Test whether the rendezvous is initialized using a shared lock, to avoid
   // the need for exclusive access in the common case.
   if (TF_PREDICT_FALSE(!is_initialized())) {
-    mutex_lock l(init_mu_);
+    mutex_lock l(mu_);
     if (!is_initialized_locked()) {
       // RecvLocalAsync can be called (due to an incoming RecvTensor RPC from a
       // remote worker) before the RunStep (or PartialRunStep) RPC from the
@@ -386,8 +386,7 @@ void BaseRemoteRendezvous::StartAbort(const Status& s) {
   local_->StartAbort(derived_status);
   {
     // Aborts all active RecvTensor calls.
-    mutex_lock l(init_mu_);
-    mutex_lock l2(active_mu_);
+    mutex_lock l(mu_);
     if (status_.ok()) {
       status_ = derived_status;
       for (auto& entry : active_) {
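Folding init_mu_ and active_mu_ into a single mu_ lets StartAbort record status_ and walk active_ in one critical section, which is what allows RegisterCall (next hunk) to make its status check, callback registration, and active_ insertion atomic with respect to an abort.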
@@ -402,42 +401,36 @@ void BaseRemoteRendezvous::StartAbort(const Status& s) {
 void BaseRemoteRendezvous::RegisterCall(BaseRecvTensorCall* call,
                                         const Rendezvous::Args& args) {
   CancellationManager* cm = args.cancellation_manager;
-  Status captured_status;
-  {
-    tf_shared_lock l(init_mu_);
-    if (!status_.ok()) {
-      captured_status = status_;
-    }
-  }
-  if (!captured_status.ok()) {
-    call->StartAbort(captured_status);
-    return;
-  }
-
   bool already_cancelled = false;
   InactiveCallback callback = [] {};
-  if (cm != nullptr) {
-    auto token = cm->get_cancellation_token();
-    already_cancelled = !cm->RegisterCallback(token, [this, call] {
-      {
-        tf_shared_lock l(active_mu_);
-        if (active_.find(call) == active_.end()) return;
-      }
-      call->StartAbort(errors::Cancelled("RecvFromRemoteAsync is cancelled."));
-    });
-    callback = [cm, token] { cm->TryDeregisterCallback(token); };
-  }
-
-  if (already_cancelled) {
-    call->StartAbort(errors::Cancelled("RecvFromRemoteAsync is cancelled."));
-  } else {
-    mutex_lock l(active_mu_);
-    CHECK(active_.emplace(call, callback).second);
+  {
+    mutex_lock l(mu_);
+    if (!status_.ok()) {
+      call->StartAbort(status_);
+      return;
+    }
+    if (cm != nullptr) {
+      auto token = cm->get_cancellation_token();
+      already_cancelled = !cm->RegisterCallback(token, [this, call] {
+        {
+          mutex_lock l(mu_);
+          if (active_.find(call) == active_.end()) return;
+          call->StartAbort(
+              errors::Cancelled("RecvFromRemoteAsync is cancelled."));
+        }
+      });
+      callback = [cm, token] { cm->TryDeregisterCallback(token); };
+    }
+    if (already_cancelled) {
+      call->StartAbort(errors::Cancelled("RecvFromRemoteAsync is cancelled."));
+    } else {
+      CHECK(active_.emplace(call, callback).second);
+    }
   }
 }
 
 void BaseRemoteRendezvous::DeregisterCall(BaseRecvTensorCall* call) {
-  mutex_lock l(active_mu_);
+  mutex_lock l(mu_);
   auto it = active_.find(call);
   if (it != active_.end()) {
     // Deregister the cancellation callback, if one was registered.
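Reusing the toy types from the sketch above, the shape of the repaired RegisterCall (again a hypothetical illustration, not the TensorFlow implementation): the aborted check, the callback registration, and the insertion into the active set all happen while the one mutex is held, so a concurrent cancellation either waits for the insert to complete or never treats the call as active.

// Hypothetical sketch of the repaired shape, reusing the toy types above
// (not the TensorFlow implementation).
void RegisterCallFixed(Call* call, bool rendezvous_aborted) {
  std::lock_guard<std::mutex> l(active_mu);  // one lock around everything
  if (rendezvous_aborted) {                  // stands in for !status_.ok()
    call->StartAbort();
    return;
  }
  cancel_cb = [call] {
    std::lock_guard<std::mutex> l(active_mu);  // waits out any in-flight registration
    if (active.find(call) != active.end()) call->StartAbort();
  };
  active.insert(call);  // insertion is atomic with the registration above
}

In the actual diff this corresponds to taking mutex_lock l(mu_) at the top of RegisterCall and keeping cm->RegisterCallback(...) and active_.emplace(...) inside that scope.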
@@ -174,14 +174,12 @@ class BaseRemoteRendezvous : public RemoteRendezvous {
 private:
  Rendezvous* local_;  // Owns a Ref on this object.
 
- // Guards mutable state that is read-mostly after this rendezvous is
- // initialized.
- mutable mutex init_mu_;
+ mutable mutex mu_;
 
  // Status given by StartAbort() if any.
- Status status_ TF_GUARDED_BY(init_mu_);
+ Status status_ TF_GUARDED_BY(mu_);
 
- WorkerSession* session_ TF_GUARDED_BY(init_mu_);  // Not owned.
+ WorkerSession* session_ TF_GUARDED_BY(mu_);  // Not owned.
 
  // Data structures to handle calls when partially initialized.
  struct DeferredCall {
@@ -190,16 +188,14 @@ class BaseRemoteRendezvous : public RemoteRendezvous {
 
    DeferredCall(const ParsedKey& parsed, DoneCallback done);
  };
- std::vector<DeferredCall> deferred_calls_ TF_GUARDED_BY(init_mu_);
+ std::vector<DeferredCall> deferred_calls_ TF_GUARDED_BY(mu_);
 
  typedef std::function<void()> InactiveCallback;
 
  // Active outstanding RecvTensor calls.
- mutex active_mu_;
  std::unordered_map<BaseRecvTensorCall*, InactiveCallback> active_
-     TF_GUARDED_BY(active_mu_);
+     TF_GUARDED_BY(mu_);
 
- bool is_initialized_locked() TF_SHARED_LOCKS_REQUIRED(init_mu_) {
+ bool is_initialized_locked() TF_SHARED_LOCKS_REQUIRED(mu_) {
    return session_ != nullptr;
  }
 
@@ -282,6 +282,7 @@ void RpcRemoteRendezvous::RecvFromRemoteAsync(
     // callback.
     call->ReleaseWorker(sess->worker_cache());
     call->done()(call->status(), Args(), Args(), Tensor(), false);
+    DeregisterCall(call);
     get_call_freelist()->Release(call);
     return;
   }
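The one-line addition in RpcRemoteRendezvous::RecvFromRemoteAsync appears to be the other half of the fix: if the call is aborted right after registration, it is now explicitly deregistered, which removes it from active_ and runs the stored deregistration callback, before the call object is returned to the freelist.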