Add Ref/Unref to make sure rendezvous outlives its cancellation callback.

Reenable test case on GPU.

PiperOrigin-RevId: 335060618
Change-Id: I59730027765febbc086438401ba2002abe58423f
This commit is contained in:
Haoyu Zhang 2020-10-02 11:03:37 -07:00 committed by TensorFlower Gardener
parent 6015d64eab
commit 0531602921
5 changed files with 39 additions and 8 deletions

View File

@ -151,7 +151,7 @@ void IntraProcessRecvAsyncImpl(const DeviceMgr* device_mgr,
RefCountedIntraProcessRendezvous::RefCountedIntraProcessRendezvous(
const DeviceMgr* device_mgr)
: device_mgr_(device_mgr) {}
: device_mgr_(device_mgr), local_(this) {}
RefCountedIntraProcessRendezvous::~RefCountedIntraProcessRendezvous() {}
@ -176,7 +176,7 @@ void RefCountedIntraProcessRendezvous::StartAbort(const Status& s) {
PrivateIntraProcessRendezvous::PrivateIntraProcessRendezvous(
const DeviceMgr* device_mgr)
: device_mgr_(device_mgr) {}
: device_mgr_(device_mgr), local_(nullptr) {}
PrivateIntraProcessRendezvous::~PrivateIntraProcessRendezvous() {}

View File

@ -187,6 +187,20 @@ void LocalRendezvous::RecvAsync(const Rendezvous::ParsedKey& key,
CancellationToken token = CancellationManager::kInvalidToken;
bool already_cancelled = false;
if (cm != nullptr) {
// Increment the refcount when cancellation manager is present, to make
// sure the rendezvous outlives the recv and its cancel callbacks.
// This refcount is dropped in exactly one of the following cases:
// (1) Recv registers cancellation callback to cm, and then cm is
// cancelled, unref in the cancellation callback;
// (2) Recv registers cancellation callback to cm, but cm is already
// cancelled, unref in the already_cancelled check;
// (3) Recv is successful, and item done callback finishes deregistering
// the cancellation callback, unref in the item done callback;
// (4) Recv is successful, but the item done callback fails to deregister
// the cancellation callback because cm already StartCancel, in this
// case the cancellation callback will be invoked by the cm anyway,
// unref in the cancellation callback.
if (rc_owner_) rc_owner_->Ref();
token = cm->get_cancellation_token();
already_cancelled = !cm->RegisterCallback(token, [this, token, key_hash] {
Item* item = nullptr;
@ -230,10 +244,14 @@ void LocalRendezvous::RecvAsync(const Rendezvous::ParsedKey& key,
Rendezvous::Args(), item->args, Tensor(), /*is_dead=*/false);
delete item;
}
// Unref case (1) and (4)
if (rc_owner_) rc_owner_->Unref();
});
}
if (already_cancelled) {
mu_.unlock();
// Unref case (2)
if (rc_owner_) rc_owner_->Unref();
done(StatusGroup::MakeDerived(
errors::Cancelled("RecvAsync is cancelled.")),
Rendezvous::Args(), recv_args, Tensor(), /*is_dead=*/false);
@ -250,10 +268,17 @@ void LocalRendezvous::RecvAsync(const Rendezvous::ParsedKey& key,
// cancellation manager may no longer be live after `done` is called.
queue->push_back(new Item(
recv_args,
[cm, token, done = std::move(done)](
[this, cm, token, done = std::move(done)](
const Status& s, const Rendezvous::Args& send_args,
const Rendezvous::Args& recv_args, const Tensor& v, bool dead) {
cm->TryDeregisterCallback(token);
// TryDeregisterCallback returns true when the cancellation callback
// is successfully deregistered. If it fails because the CM already
// StartAbort, Unref will happen inside the cancellation callback
// when called by the CM.
if (cm->TryDeregisterCallback(token)) {
// Unref case (3)
if (this->rc_owner_) this->rc_owner_->Unref();
}
done(s, send_args, recv_args, v, dead);
},
token));

View File

@ -35,7 +35,11 @@ namespace tensorflow {
// is not expected to be needed.
class LocalRendezvous {
public:
LocalRendezvous() = default;
// If the class wrapping LocalRendezvous is refcounted (i.e., extending
// Rendezvous), pass in its pointer in constructor so the LocalRendezvous
// can make sure it outlives the async recv requests.
// Pass in nullptr if the wrapping class is not refcounted.
explicit LocalRendezvous(Rendezvous* owner) : rc_owner_(owner) {}
~LocalRendezvous();
Status Send(const Rendezvous::ParsedKey& key,
@ -62,6 +66,9 @@ class LocalRendezvous {
typedef gtl::FlatMap<uint64, ItemQueue> Table;
// Pointer to the owner class of this LocalRendezvous if it is refcounted.
const Rendezvous* rc_owner_;
// TODO(zhifengc): shard table_.
mutex mu_;
Table table_ TF_GUARDED_BY(mu_);

View File

@ -151,7 +151,7 @@ Status RendezvousInterface::Recv(const ParsedKey& key, const Args& args,
namespace {
class LocalRendezvousWrapper : public Rendezvous {
public:
LocalRendezvousWrapper() = default;
LocalRendezvousWrapper() : impl_(this) {}
Status Send(const ParsedKey& key, const Args& send_args, const Tensor& val,
const bool is_dead) override {

View File

@ -3364,8 +3364,7 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
with self.assertRaises(errors.CancelledError):
cancelable_func()
# TODO(b/162544929): Enable this test.
def DISABLE_testCancelBlockedFunctionExecution(self):
def testCancelBlockedFunctionExecution(self):
if not context.executing_eagerly():
self.skipTest('eager only')