Add Ref/Unref to make sure rendezvous outlives its cancellation callback.
Reenable test case on GPU. PiperOrigin-RevId: 335060618 Change-Id: I59730027765febbc086438401ba2002abe58423f
This commit is contained in:
parent
6015d64eab
commit
0531602921
tensorflow
core
python/eager
@ -151,7 +151,7 @@ void IntraProcessRecvAsyncImpl(const DeviceMgr* device_mgr,
|
|||||||
|
|
||||||
RefCountedIntraProcessRendezvous::RefCountedIntraProcessRendezvous(
|
RefCountedIntraProcessRendezvous::RefCountedIntraProcessRendezvous(
|
||||||
const DeviceMgr* device_mgr)
|
const DeviceMgr* device_mgr)
|
||||||
: device_mgr_(device_mgr) {}
|
: device_mgr_(device_mgr), local_(this) {}
|
||||||
|
|
||||||
RefCountedIntraProcessRendezvous::~RefCountedIntraProcessRendezvous() {}
|
RefCountedIntraProcessRendezvous::~RefCountedIntraProcessRendezvous() {}
|
||||||
|
|
||||||
@ -176,7 +176,7 @@ void RefCountedIntraProcessRendezvous::StartAbort(const Status& s) {
|
|||||||
|
|
||||||
PrivateIntraProcessRendezvous::PrivateIntraProcessRendezvous(
|
PrivateIntraProcessRendezvous::PrivateIntraProcessRendezvous(
|
||||||
const DeviceMgr* device_mgr)
|
const DeviceMgr* device_mgr)
|
||||||
: device_mgr_(device_mgr) {}
|
: device_mgr_(device_mgr), local_(nullptr) {}
|
||||||
|
|
||||||
PrivateIntraProcessRendezvous::~PrivateIntraProcessRendezvous() {}
|
PrivateIntraProcessRendezvous::~PrivateIntraProcessRendezvous() {}
|
||||||
|
|
||||||
|
@ -187,6 +187,20 @@ void LocalRendezvous::RecvAsync(const Rendezvous::ParsedKey& key,
|
|||||||
CancellationToken token = CancellationManager::kInvalidToken;
|
CancellationToken token = CancellationManager::kInvalidToken;
|
||||||
bool already_cancelled = false;
|
bool already_cancelled = false;
|
||||||
if (cm != nullptr) {
|
if (cm != nullptr) {
|
||||||
|
// Increment the refcount when cancellation manager is present, to make
|
||||||
|
// sure the rendezvous outlives the recv and its cancel callbacks.
|
||||||
|
// This refcount is dropped in exactly one of the following cases:
|
||||||
|
// (1) Recv registers cancellation callback to cm, and then cm is
|
||||||
|
// cancelled, unref in the cancellation callback;
|
||||||
|
// (2) Recv registers cancellation callback to cm, but cm is already
|
||||||
|
// cancelled, unref in the already_cancelled check;
|
||||||
|
// (3) Recv is successful, and item done callback finishes deregistering
|
||||||
|
// the cancellation callback, unref in the item done callback;
|
||||||
|
// (4) Recv is successful, but the item done callback fails to deregister
|
||||||
|
// the cancellation callback because cm already StartCancel, in this
|
||||||
|
// case the cancellation callback will be invoked by the cm anyway,
|
||||||
|
// unref in the cancellation callback.
|
||||||
|
if (rc_owner_) rc_owner_->Ref();
|
||||||
token = cm->get_cancellation_token();
|
token = cm->get_cancellation_token();
|
||||||
already_cancelled = !cm->RegisterCallback(token, [this, token, key_hash] {
|
already_cancelled = !cm->RegisterCallback(token, [this, token, key_hash] {
|
||||||
Item* item = nullptr;
|
Item* item = nullptr;
|
||||||
@ -230,10 +244,14 @@ void LocalRendezvous::RecvAsync(const Rendezvous::ParsedKey& key,
|
|||||||
Rendezvous::Args(), item->args, Tensor(), /*is_dead=*/false);
|
Rendezvous::Args(), item->args, Tensor(), /*is_dead=*/false);
|
||||||
delete item;
|
delete item;
|
||||||
}
|
}
|
||||||
|
// Unref case (1) and (4)
|
||||||
|
if (rc_owner_) rc_owner_->Unref();
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
if (already_cancelled) {
|
if (already_cancelled) {
|
||||||
mu_.unlock();
|
mu_.unlock();
|
||||||
|
// Unref case (2)
|
||||||
|
if (rc_owner_) rc_owner_->Unref();
|
||||||
done(StatusGroup::MakeDerived(
|
done(StatusGroup::MakeDerived(
|
||||||
errors::Cancelled("RecvAsync is cancelled.")),
|
errors::Cancelled("RecvAsync is cancelled.")),
|
||||||
Rendezvous::Args(), recv_args, Tensor(), /*is_dead=*/false);
|
Rendezvous::Args(), recv_args, Tensor(), /*is_dead=*/false);
|
||||||
@ -250,10 +268,17 @@ void LocalRendezvous::RecvAsync(const Rendezvous::ParsedKey& key,
|
|||||||
// cancellation manager may no longer be live after `done` is called.
|
// cancellation manager may no longer be live after `done` is called.
|
||||||
queue->push_back(new Item(
|
queue->push_back(new Item(
|
||||||
recv_args,
|
recv_args,
|
||||||
[cm, token, done = std::move(done)](
|
[this, cm, token, done = std::move(done)](
|
||||||
const Status& s, const Rendezvous::Args& send_args,
|
const Status& s, const Rendezvous::Args& send_args,
|
||||||
const Rendezvous::Args& recv_args, const Tensor& v, bool dead) {
|
const Rendezvous::Args& recv_args, const Tensor& v, bool dead) {
|
||||||
cm->TryDeregisterCallback(token);
|
// TryDeregisterCallback returns true when the cancellation callback
|
||||||
|
// is successfully deregistered. If it fails because the CM already
|
||||||
|
// StartAbort, Unref will happen inside the cancellation callback
|
||||||
|
// when called by the CM.
|
||||||
|
if (cm->TryDeregisterCallback(token)) {
|
||||||
|
// Unref case (3)
|
||||||
|
if (this->rc_owner_) this->rc_owner_->Unref();
|
||||||
|
}
|
||||||
done(s, send_args, recv_args, v, dead);
|
done(s, send_args, recv_args, v, dead);
|
||||||
},
|
},
|
||||||
token));
|
token));
|
||||||
|
@ -35,7 +35,11 @@ namespace tensorflow {
|
|||||||
// is not expected to be needed.
|
// is not expected to be needed.
|
||||||
class LocalRendezvous {
|
class LocalRendezvous {
|
||||||
public:
|
public:
|
||||||
LocalRendezvous() = default;
|
// If the class wrapping LocalRendezvous is refcounted (i.e., extending
|
||||||
|
// Rendezvous), pass in its pointer in constructor so the LocalRendezvous
|
||||||
|
// can make sure it outlives the async recv requests.
|
||||||
|
// Pass in nullptr if the wrapping class is not refcounted.
|
||||||
|
explicit LocalRendezvous(Rendezvous* owner) : rc_owner_(owner) {}
|
||||||
~LocalRendezvous();
|
~LocalRendezvous();
|
||||||
|
|
||||||
Status Send(const Rendezvous::ParsedKey& key,
|
Status Send(const Rendezvous::ParsedKey& key,
|
||||||
@ -62,6 +66,9 @@ class LocalRendezvous {
|
|||||||
|
|
||||||
typedef gtl::FlatMap<uint64, ItemQueue> Table;
|
typedef gtl::FlatMap<uint64, ItemQueue> Table;
|
||||||
|
|
||||||
|
// Pointer to the owner class of this LocalRendezvous if it is refcounted.
|
||||||
|
const Rendezvous* rc_owner_;
|
||||||
|
|
||||||
// TODO(zhifengc): shard table_.
|
// TODO(zhifengc): shard table_.
|
||||||
mutex mu_;
|
mutex mu_;
|
||||||
Table table_ TF_GUARDED_BY(mu_);
|
Table table_ TF_GUARDED_BY(mu_);
|
||||||
|
@ -151,7 +151,7 @@ Status RendezvousInterface::Recv(const ParsedKey& key, const Args& args,
|
|||||||
namespace {
|
namespace {
|
||||||
class LocalRendezvousWrapper : public Rendezvous {
|
class LocalRendezvousWrapper : public Rendezvous {
|
||||||
public:
|
public:
|
||||||
LocalRendezvousWrapper() = default;
|
LocalRendezvousWrapper() : impl_(this) {}
|
||||||
|
|
||||||
Status Send(const ParsedKey& key, const Args& send_args, const Tensor& val,
|
Status Send(const ParsedKey& key, const Args& send_args, const Tensor& val,
|
||||||
const bool is_dead) override {
|
const bool is_dead) override {
|
||||||
|
@ -3364,8 +3364,7 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
|
|||||||
with self.assertRaises(errors.CancelledError):
|
with self.assertRaises(errors.CancelledError):
|
||||||
cancelable_func()
|
cancelable_func()
|
||||||
|
|
||||||
# TODO(b/162544929): Enable this test.
|
def testCancelBlockedFunctionExecution(self):
|
||||||
def DISABLE_testCancelBlockedFunctionExecution(self):
|
|
||||||
if not context.executing_eagerly():
|
if not context.executing_eagerly():
|
||||||
self.skipTest('eager only')
|
self.skipTest('eager only')
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user