Add Ref/Unref to make sure rendezvous outlives its cancellation callback.

Reenable test case on GPU.

PiperOrigin-RevId: 335060618
Change-Id: I59730027765febbc086438401ba2002abe58423f
This commit is contained in:
Haoyu Zhang 2020-10-02 11:03:37 -07:00 committed by TensorFlower Gardener
parent 6015d64eab
commit 0531602921
5 changed files with 39 additions and 8 deletions

View File

@ -151,7 +151,7 @@ void IntraProcessRecvAsyncImpl(const DeviceMgr* device_mgr,
RefCountedIntraProcessRendezvous::RefCountedIntraProcessRendezvous( RefCountedIntraProcessRendezvous::RefCountedIntraProcessRendezvous(
const DeviceMgr* device_mgr) const DeviceMgr* device_mgr)
: device_mgr_(device_mgr) {} : device_mgr_(device_mgr), local_(this) {}
RefCountedIntraProcessRendezvous::~RefCountedIntraProcessRendezvous() {} RefCountedIntraProcessRendezvous::~RefCountedIntraProcessRendezvous() {}
@ -176,7 +176,7 @@ void RefCountedIntraProcessRendezvous::StartAbort(const Status& s) {
PrivateIntraProcessRendezvous::PrivateIntraProcessRendezvous( PrivateIntraProcessRendezvous::PrivateIntraProcessRendezvous(
const DeviceMgr* device_mgr) const DeviceMgr* device_mgr)
: device_mgr_(device_mgr) {} : device_mgr_(device_mgr), local_(nullptr) {}
PrivateIntraProcessRendezvous::~PrivateIntraProcessRendezvous() {} PrivateIntraProcessRendezvous::~PrivateIntraProcessRendezvous() {}

View File

@ -187,6 +187,20 @@ void LocalRendezvous::RecvAsync(const Rendezvous::ParsedKey& key,
CancellationToken token = CancellationManager::kInvalidToken; CancellationToken token = CancellationManager::kInvalidToken;
bool already_cancelled = false; bool already_cancelled = false;
if (cm != nullptr) { if (cm != nullptr) {
// Increment the refcount when cancellation manager is present, to make
// sure the rendezvous outlives the recv and its cancel callbacks.
// This refcount is dropped in exactly one of the following cases:
// (1) Recv registers cancellation callback to cm, and then cm is
// cancelled, unref in the cancellation callback;
// (2) Recv registers cancellation callback to cm, but cm is already
// cancelled, unref in the already_cancelled check;
// (3) Recv is successful, and item done callback finishes deregistering
// the cancellation callback, unref in the item done callback;
// (4) Recv is successful, but the item done callback fails to deregister
// the cancellation callback because cm has already started cancelling
// (StartCancel). In this case the cancellation callback will be
// invoked by the cm anyway, so unref in the cancellation callback.
if (rc_owner_) rc_owner_->Ref();
token = cm->get_cancellation_token(); token = cm->get_cancellation_token();
already_cancelled = !cm->RegisterCallback(token, [this, token, key_hash] { already_cancelled = !cm->RegisterCallback(token, [this, token, key_hash] {
Item* item = nullptr; Item* item = nullptr;
@ -230,10 +244,14 @@ void LocalRendezvous::RecvAsync(const Rendezvous::ParsedKey& key,
Rendezvous::Args(), item->args, Tensor(), /*is_dead=*/false); Rendezvous::Args(), item->args, Tensor(), /*is_dead=*/false);
delete item; delete item;
} }
// Unref case (1) and (4)
if (rc_owner_) rc_owner_->Unref();
}); });
} }
if (already_cancelled) { if (already_cancelled) {
mu_.unlock(); mu_.unlock();
// Unref case (2)
if (rc_owner_) rc_owner_->Unref();
done(StatusGroup::MakeDerived( done(StatusGroup::MakeDerived(
errors::Cancelled("RecvAsync is cancelled.")), errors::Cancelled("RecvAsync is cancelled.")),
Rendezvous::Args(), recv_args, Tensor(), /*is_dead=*/false); Rendezvous::Args(), recv_args, Tensor(), /*is_dead=*/false);
@ -250,10 +268,17 @@ void LocalRendezvous::RecvAsync(const Rendezvous::ParsedKey& key,
// cancellation manager may no longer be live after `done` is called. // cancellation manager may no longer be live after `done` is called.
queue->push_back(new Item( queue->push_back(new Item(
recv_args, recv_args,
[cm, token, done = std::move(done)]( [this, cm, token, done = std::move(done)](
const Status& s, const Rendezvous::Args& send_args, const Status& s, const Rendezvous::Args& send_args,
const Rendezvous::Args& recv_args, const Tensor& v, bool dead) { const Rendezvous::Args& recv_args, const Tensor& v, bool dead) {
cm->TryDeregisterCallback(token); // TryDeregisterCallback returns true when the cancellation callback
// is successfully deregistered. If it fails because the CM has already
// called StartCancel, Unref will happen inside the cancellation
// callback when it is invoked by the CM.
if (cm->TryDeregisterCallback(token)) {
// Unref case (3)
if (this->rc_owner_) this->rc_owner_->Unref();
}
done(s, send_args, recv_args, v, dead); done(s, send_args, recv_args, v, dead);
}, },
token)); token));

View File

@ -35,7 +35,11 @@ namespace tensorflow {
// is not expected to be needed. // is not expected to be needed.
class LocalRendezvous { class LocalRendezvous {
public: public:
LocalRendezvous() = default; // If the class wrapping LocalRendezvous is refcounted (i.e., extending
// Rendezvous), pass in its pointer in constructor so the LocalRendezvous
// can make sure it outlives the async recv requests.
// Pass in nullptr if the wrapping class is not refcounted.
explicit LocalRendezvous(Rendezvous* owner) : rc_owner_(owner) {}
~LocalRendezvous(); ~LocalRendezvous();
Status Send(const Rendezvous::ParsedKey& key, Status Send(const Rendezvous::ParsedKey& key,
@ -62,6 +66,9 @@ class LocalRendezvous {
typedef gtl::FlatMap<uint64, ItemQueue> Table; typedef gtl::FlatMap<uint64, ItemQueue> Table;
// Pointer to the owner class of this LocalRendezvous if it is refcounted.
const Rendezvous* rc_owner_;
// TODO(zhifengc): shard table_. // TODO(zhifengc): shard table_.
mutex mu_; mutex mu_;
Table table_ TF_GUARDED_BY(mu_); Table table_ TF_GUARDED_BY(mu_);

View File

@ -151,7 +151,7 @@ Status RendezvousInterface::Recv(const ParsedKey& key, const Args& args,
namespace { namespace {
class LocalRendezvousWrapper : public Rendezvous { class LocalRendezvousWrapper : public Rendezvous {
public: public:
LocalRendezvousWrapper() = default; LocalRendezvousWrapper() : impl_(this) {}
Status Send(const ParsedKey& key, const Args& send_args, const Tensor& val, Status Send(const ParsedKey& key, const Args& send_args, const Tensor& val,
const bool is_dead) override { const bool is_dead) override {

View File

@ -3364,8 +3364,7 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
with self.assertRaises(errors.CancelledError): with self.assertRaises(errors.CancelledError):
cancelable_func() cancelable_func()
# TODO(b/162544929): Enable this test. def testCancelBlockedFunctionExecution(self):
def DISABLE_testCancelBlockedFunctionExecution(self):
if not context.executing_eagerly(): if not context.executing_eagerly():
self.skipTest('eager only') self.skipTest('eager only')