Add Ref/Unref to make sure rendezvous outlives its cancellation callback.
Re-enable test case on GPU. PiperOrigin-RevId: 335060618 Change-Id: I59730027765febbc086438401ba2002abe58423f
This commit is contained in:
parent
6015d64eab
commit
0531602921
@ -151,7 +151,7 @@ void IntraProcessRecvAsyncImpl(const DeviceMgr* device_mgr,
|
||||
|
||||
RefCountedIntraProcessRendezvous::RefCountedIntraProcessRendezvous(
|
||||
const DeviceMgr* device_mgr)
|
||||
: device_mgr_(device_mgr) {}
|
||||
: device_mgr_(device_mgr), local_(this) {}
|
||||
|
||||
RefCountedIntraProcessRendezvous::~RefCountedIntraProcessRendezvous() {}
|
||||
|
||||
@ -176,7 +176,7 @@ void RefCountedIntraProcessRendezvous::StartAbort(const Status& s) {
|
||||
|
||||
PrivateIntraProcessRendezvous::PrivateIntraProcessRendezvous(
|
||||
const DeviceMgr* device_mgr)
|
||||
: device_mgr_(device_mgr) {}
|
||||
: device_mgr_(device_mgr), local_(nullptr) {}
|
||||
|
||||
PrivateIntraProcessRendezvous::~PrivateIntraProcessRendezvous() {}
|
||||
|
||||
|
@ -187,6 +187,20 @@ void LocalRendezvous::RecvAsync(const Rendezvous::ParsedKey& key,
|
||||
CancellationToken token = CancellationManager::kInvalidToken;
|
||||
bool already_cancelled = false;
|
||||
if (cm != nullptr) {
|
||||
// Increment the refcount when cancellation manager is present, to make
|
||||
// sure the rendezvous outlives the recv and its cancel callbacks.
|
||||
// This refcount is dropped in exactly one of the following cases:
|
||||
// (1) Recv registers cancellation callback to cm, and then cm is
|
||||
// cancelled, unref in the cancellation callback;
|
||||
// (2) Recv registers cancellation callback to cm, but cm is already
|
||||
// cancelled, unref in the already_cancelled check;
|
||||
// (3) Recv is successful, and item done callback finishes deregistering
|
||||
// the cancellation callback, unref in the item done callback;
|
||||
// (4) Recv is successful, but the item done callback fails to deregister
|
||||
// the cancellation callback because StartCancel was already called on cm; in this
|
||||
// case the cancellation callback will be invoked by the cm anyway,
|
||||
// unref in the cancellation callback.
|
||||
if (rc_owner_) rc_owner_->Ref();
|
||||
token = cm->get_cancellation_token();
|
||||
already_cancelled = !cm->RegisterCallback(token, [this, token, key_hash] {
|
||||
Item* item = nullptr;
|
||||
@ -230,10 +244,14 @@ void LocalRendezvous::RecvAsync(const Rendezvous::ParsedKey& key,
|
||||
Rendezvous::Args(), item->args, Tensor(), /*is_dead=*/false);
|
||||
delete item;
|
||||
}
|
||||
// Unref case (1) and (4)
|
||||
if (rc_owner_) rc_owner_->Unref();
|
||||
});
|
||||
}
|
||||
if (already_cancelled) {
|
||||
mu_.unlock();
|
||||
// Unref case (2)
|
||||
if (rc_owner_) rc_owner_->Unref();
|
||||
done(StatusGroup::MakeDerived(
|
||||
errors::Cancelled("RecvAsync is cancelled.")),
|
||||
Rendezvous::Args(), recv_args, Tensor(), /*is_dead=*/false);
|
||||
@ -250,10 +268,17 @@ void LocalRendezvous::RecvAsync(const Rendezvous::ParsedKey& key,
|
||||
// cancellation manager may no longer be live after `done` is called.
|
||||
queue->push_back(new Item(
|
||||
recv_args,
|
||||
[cm, token, done = std::move(done)](
|
||||
[this, cm, token, done = std::move(done)](
|
||||
const Status& s, const Rendezvous::Args& send_args,
|
||||
const Rendezvous::Args& recv_args, const Tensor& v, bool dead) {
|
||||
cm->TryDeregisterCallback(token);
|
||||
// TryDeregisterCallback returns true when the cancellation callback
|
||||
// is successfully deregistered. If it fails because the CM already
|
||||
// started cancelling (StartCancel), Unref will happen inside the cancellation callback
|
||||
// when called by the CM.
|
||||
if (cm->TryDeregisterCallback(token)) {
|
||||
// Unref case (3)
|
||||
if (this->rc_owner_) this->rc_owner_->Unref();
|
||||
}
|
||||
done(s, send_args, recv_args, v, dead);
|
||||
},
|
||||
token));
|
||||
|
@ -35,7 +35,11 @@ namespace tensorflow {
|
||||
// is not expected to be needed.
|
||||
class LocalRendezvous {
|
||||
public:
|
||||
LocalRendezvous() = default;
|
||||
// If the class wrapping LocalRendezvous is refcounted (i.e., extending
|
||||
// Rendezvous), pass in its pointer in constructor so the LocalRendezvous
|
||||
// can make sure it outlives the async recv requests.
|
||||
// Pass in nullptr if the wrapping class is not refcounted.
|
||||
explicit LocalRendezvous(Rendezvous* owner) : rc_owner_(owner) {}
|
||||
~LocalRendezvous();
|
||||
|
||||
Status Send(const Rendezvous::ParsedKey& key,
|
||||
@ -62,6 +66,9 @@ class LocalRendezvous {
|
||||
|
||||
typedef gtl::FlatMap<uint64, ItemQueue> Table;
|
||||
|
||||
// Pointer to the owner class of this LocalRendezvous if it is refcounted.
|
||||
const Rendezvous* rc_owner_;
|
||||
|
||||
// TODO(zhifengc): shard table_.
|
||||
mutex mu_;
|
||||
Table table_ TF_GUARDED_BY(mu_);
|
||||
|
@ -151,7 +151,7 @@ Status RendezvousInterface::Recv(const ParsedKey& key, const Args& args,
|
||||
namespace {
|
||||
class LocalRendezvousWrapper : public Rendezvous {
|
||||
public:
|
||||
LocalRendezvousWrapper() = default;
|
||||
LocalRendezvousWrapper() : impl_(this) {}
|
||||
|
||||
Status Send(const ParsedKey& key, const Args& send_args, const Tensor& val,
|
||||
const bool is_dead) override {
|
||||
|
@ -3364,8 +3364,7 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
|
||||
with self.assertRaises(errors.CancelledError):
|
||||
cancelable_func()
|
||||
|
||||
# TODO(b/162544929): Enable this test.
|
||||
def DISABLE_testCancelBlockedFunctionExecution(self):
|
||||
def testCancelBlockedFunctionExecution(self):
|
||||
if not context.executing_eagerly():
|
||||
self.skipTest('eager only')
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user