Enable BufRendezvous per-entry cancellation.

This change adds the ability to cancel pending transfers in BufRendezvous via a
CancellationManager.  When a non-null CancellationManager is passed to
ProvideBuf or ConsumeBuf, the function registers a callback with that
cancellation manager.  The callback, invoked when cancellation starts, fires the
pending producer and consumer callbacks for that key in BufRendezvous with a
Cancelled status and cleans up the associated state.

This enables finer-grained cancellation of collectives.  For example, after
this change it is possible to pass in the cancellation manager from the
OpKernelContext, which cancels pending transfers without permanently aborting
the BufRendezvous.
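
For illustration, here is a minimal sketch of the intended call pattern, modeled
on the PostThenCancel test added by this change. The device/task strings are
placeholders, and `rma`, `cpu0`, `attr`, `locality`, and `tensor` are assumed to
be set up as in the CollectiveRemoteAccessLocalTest fixture, so this only builds
inside the TensorFlow tree:

#include "tensorflow/core/common_runtime/collective_rma_local.h"
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/lib/core/notification.h"

namespace tensorflow {

// Sketch: cancel one pending PostToPeer without aborting the rendezvous.
void PostThenCancelSketch(CollectiveRemoteAccessLocal* rma, Device* cpu0,
                          const AllocatorAttributes& attr,
                          const DeviceLocality& locality, const Tensor* tensor) {
  CancellationManager cancellation_manager;
  Notification note;
  Status send_status;
  rma->PostToPeer("/task:0/device:CPU:0" /*placeholder peer_device*/,
                  "/task:0" /*placeholder peer_task*/, "key_0",
                  cpu0 /*from_device*/, nullptr /*from_device_ctx*/, attr,
                  tensor, locality, &cancellation_manager,
                  [&note, &send_status](const Status& s) {
                    send_status = s;
                    note.Notify();
                  });
  // Cancelling the manager fires the callback registered for this key only.
  cancellation_manager.StartCancel();
  note.WaitForNotification();
  // send_status now holds a Cancelled error; other keys and later uses of the
  // BufRendezvous are unaffected, unlike StartAbort.
}

}  // namespace tensorflow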

PiperOrigin-RevId: 338122132
Change-Id: I930736ac3130878c13d29d6527960d9b74f1fb03
Author: Ayush Dubey (2020-10-20 13:12:04 -07:00), committed by TensorFlower Gardener
Commit: 5c90e46f35 (parent: a7d75e16d3)
18 changed files with 475 additions and 100 deletions

View File

@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/process_util.h"
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/notification.h"
@ -51,6 +52,9 @@ void BufRendezvous::StartAbort(const Status& s) {
void BufRendezvous::PurgeTable(const Status& s, HookTable* table) {
for (auto& it : *table) {
Hook* h = it.second;
if (h->cancellation_manager != nullptr) {
h->cancellation_manager->TryDeregisterCallback(h->cancellation_token);
}
if (h->cons_cb != nullptr) {
h->cons_cb(s, nullptr);
}
@ -73,7 +77,8 @@ string BufRendezvous::Hook::DebugString() const {
void BufRendezvous::ProvideBuf(const string& key, Device* dev,
DeviceContext* dev_ctx, const Tensor* v,
const AllocatorAttributes& attr,
const ProducerCallback& done) {
const ProducerCallback& done,
CancellationManager* cancellation_manager) {
Hook* h = nullptr;
Status providebuf_status;
do {
@ -82,9 +87,13 @@ void BufRendezvous::ProvideBuf(const string& key, Device* dev,
providebuf_status = status_;
break;
} else {
CancellationToken cancellation_token = CancellationManager::kInvalidToken;
auto it = hook_table_.find(key);
if (it == hook_table_.end()) {
h = new Hook;
if (cancellation_manager != nullptr) {
cancellation_token = cancellation_manager->get_cancellation_token();
}
h = new Hook(cancellation_manager, cancellation_token);
it = hook_table_.insert(std::make_pair(key, h)).first;
} else {
if (it->second->prod_cb != nullptr) {
@ -100,10 +109,21 @@ void BufRendezvous::ProvideBuf(const string& key, Device* dev,
h->prod_value = v;
h->prod_attr = attr;
h->prod_cb = done;
// If consumer is waiting, kick off right away, removing Hook from table.
if (h->cons_cb != nullptr) {
// If consumer is waiting, kick off right away, removing Hook from
// table.
hook_table_.erase(it);
} else {
if (cancellation_manager != nullptr &&
!cancellation_manager->RegisterCallback(
cancellation_token, [this, key]() { CancelHook(key); })) {
// Register cancellation callback with CancellationManager. If it is
// already cancelled, call done immediately with cancelled status.
providebuf_status = errors::Cancelled(
"Operation was cancelled for BufRendezvous key ", key);
hook_table_.erase(it);
delete h;
}
h = nullptr;
}
}
@ -118,7 +138,8 @@ void BufRendezvous::ProvideBuf(const string& key, Device* dev,
void BufRendezvous::ConsumeBuf(const string& key, const string& device_name,
const uint64 device_incarnation,
const ConsumerCallback& done) {
const ConsumerCallback& done,
CancellationManager* cancellation_manager) {
// Check the incarnation in the request matches the current device
// incarnation of the producer.
Device* device;
@ -157,10 +178,22 @@ void BufRendezvous::ConsumeBuf(const string& key, const string& device_name,
existing_hook->cons_cb = done;
} else {
// Hang consumer callback on the Hook.
Hook* h = new Hook;
hook_table_[key] = h;
h->cons_cb = done;
return;
CancellationToken cancellation_token = CancellationManager::kInvalidToken;
bool already_cancelled = false;
if (cancellation_manager != nullptr) {
cancellation_token = cancellation_manager->get_cancellation_token();
already_cancelled = !cancellation_manager->RegisterCallback(
cancellation_token, [this, key]() { CancelHook(key); });
}
if (already_cancelled) {
consumebuf_status = errors::Cancelled(
"Operation was cancelled for BufRendezvous key ", key);
} else {
Hook* h = new Hook(cancellation_manager, cancellation_token);
h->cons_cb = done;
it = hook_table_.insert(std::make_pair(key, h)).first;
return;
}
}
} while (false);
if (existing_hook) {
@ -173,8 +206,33 @@ void BufRendezvous::ConsumeBuf(const string& key, const string& device_name,
}
}
void BufRendezvous::CancelHook(const string& key) {
Hook* h = nullptr;
{
mutex_lock l(mu_);
auto it = hook_table_.find(key);
if (it == hook_table_.end()) return;
h = it->second;
hook_table_.erase(it);
}
if (h != nullptr) {
auto s = errors::Cancelled("Operation was cancelled for BufRendezvous key ",
key);
if (h->prod_cb != nullptr) {
h->prod_cb(s);
}
if (h->cons_cb != nullptr) {
h->cons_cb(s, /*Hook=*/nullptr);
}
delete h;
}
}
/*static*/
void BufRendezvous::DoneWithHook(Hook* h) {
if (h->cancellation_manager != nullptr) {
h->cancellation_manager->DeregisterCallback(h->cancellation_token);
}
h->prod_cb(Status::OK());
delete h;
}
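
As a side note, here is a minimal sketch (not part of this change) of the
CancellationManager contract the registration logic above relies on:
RegisterCallback returns false once cancellation has already started, which is
how ProvideBuf and ConsumeBuf detect the "already cancelled" case and complete
the pending callback with a Cancelled status themselves.

#include "tensorflow/core/framework/cancellation.h"

namespace tensorflow {

void CancellationContractSketch() {
  CancellationManager cm;
  CancellationToken token = cm.get_cancellation_token();
  // Arms the callback; it runs when StartCancel() is invoked.
  bool registered = cm.RegisterCallback(token, []() { /* CancelHook(key) */ });
  cm.StartCancel();  // Runs every callback registered so far.

  // A registration attempted after cancellation fails immediately.
  CancellationToken late_token = cm.get_cancellation_token();
  bool late_registered = cm.RegisterCallback(late_token, []() {});
  // registered == true, late_registered == false.
  (void)registered;
  (void)late_registered;
}

}  // namespace tensorflow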

View File

@ -21,6 +21,7 @@ limitations under the License.
#include "absl/container/flat_hash_map.h"
#include "absl/strings/string_view.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/mutex.h"
@ -66,20 +67,30 @@ class BufRendezvous {
AllocatorAttributes prod_attr;
ProducerCallback prod_cb;
ConsumerCallback cons_cb;
Hook()
CancellationManager* cancellation_manager;
CancellationToken cancellation_token;
explicit Hook(CancellationManager* cancellation_manager,
CancellationToken cancellation_token)
: prod_dev(nullptr),
prod_ctx(nullptr),
prod_value(nullptr),
prod_cb(nullptr),
cons_cb(nullptr) {}
cons_cb(nullptr),
cancellation_manager(cancellation_manager),
cancellation_token(cancellation_token) {}
string DebugString() const;
};
// Called to advertise availability of a Tensor value corresponding
// to key. That value must stay valid until done is called.
//
// If a non-null cancellation manager is provided, this function registers a
// callback to delete the hook and invoke provider/consumer callbacks with
// cancelled error.
void ProvideBuf(const string& key, Device* dev, DeviceContext* dev_ctx,
const Tensor* v, const AllocatorAttributes& attr,
const ProducerCallback& done);
const ProducerCallback& done,
CancellationManager* cancellation_manager);
// Called to request access to a Tensor value corresponding to key.
// Consumer is provided with a Hook as soon as available.
@ -88,8 +99,17 @@ class BufRendezvous {
// `device` that produced this value matches the `incarnation` expected by the
// consumer, and invokes `done` with `FailedPrecondition` status and
// `nullptr` hook if it does not match.
//
// If a non-null cancellation manager is provided, this function registers a
// callback to delete the hook and invoke provider/consumer callbacks with
// cancelled error.
void ConsumeBuf(const string& key, const string& device,
const uint64 incarnation, const ConsumerCallback& done);
const uint64 incarnation, const ConsumerCallback& done,
CancellationManager* cancellation_manager);
// Cancel the rendezvous entry corresponding to `key`. Triggered by the
// cancellation manager. No-op if the rendezvous was already successful.
void CancelHook(const string& key);
// Consumer must call this function when it's done reading the Hook provided
// by the ConsumerCallback. This function will invoke the producer callback
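
A hedged sketch of how the two new signatures fit together, mirroring the
updated tests below. The key string is arbitrary, and `br`, `dev`, `dev_ctx`,
`attr`, `tensor`, and `incarnation` are assumed to be set up as in
BufRendezvousTest; this only builds inside the TensorFlow tree:

#include "tensorflow/core/common_runtime/buf_rendezvous.h"
#include "tensorflow/core/framework/cancellation.h"

namespace tensorflow {

void ProvideConsumeSketch(BufRendezvous* br, Device* dev,
                          DeviceContext* dev_ctx,
                          const AllocatorAttributes& attr, const Tensor* tensor,
                          uint64 incarnation, CancellationManager* cm) {
  // Producer side: advertise the tensor and register with `cm`.
  br->ProvideBuf("exchange_key", dev, dev_ctx, tensor, attr,
                 [](const Status&) { /* producer may now release the tensor */ },
                 cm);
  // Consumer side: receive a Hook, read it, then hand it back.
  br->ConsumeBuf("exchange_key", dev->name(), incarnation,
                 [](const Status& s, BufRendezvous::Hook* hook) {
                   if (s.ok()) {
                     // Read hook->prod_value here. DoneWithHook invokes the
                     // producer callback and deregisters the hook's
                     // cancellation callback.
                     BufRendezvous::DoneWithHook(hook);
                   }
                 },
                 cm);
  // If cm->StartCancel() fires while either side is still pending,
  // CancelHook("exchange_key") runs and the waiting callback receives a
  // Cancelled status instead.
}

}  // namespace tensorflow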

View File

@ -68,6 +68,7 @@ class BufRendezvousTest : public ::testing::Test {
DeviceContext* fake_device_context_;
std::unique_ptr<DeviceMgr> dev_mgr_;
std::unique_ptr<BufRendezvous> br_;
CancellationManager cm_;
static const string* const kDefaultKey;
static const string* const kDefaultDeviceName;
static const uint64 kDefaultIncarnation;
@ -90,19 +91,22 @@ TEST_F(BufRendezvousTest, CorrectUseProducerFirst) {
prod_status = s;
prod_callback_called = true;
note.Notify();
});
},
&cm_);
EXPECT_FALSE(prod_callback_called);
br_->ConsumeBuf(*kDefaultKey, *kDefaultDeviceName, kDefaultIncarnation,
[this, &cons_status, &cons_callback_called](
const Status& s, BufRendezvous::Hook* h) {
cons_status = s;
cons_callback_called = true;
ASSERT_TRUE(h != nullptr);
EXPECT_EQ(h->prod_dev, default_device_);
EXPECT_EQ(h->prod_ctx, fake_device_context_);
EXPECT_EQ(h->prod_value, &a_);
br_->DoneWithHook(h);
});
br_->ConsumeBuf(
*kDefaultKey, *kDefaultDeviceName, kDefaultIncarnation,
[this, &cons_status, &cons_callback_called](const Status& s,
BufRendezvous::Hook* h) {
cons_status = s;
cons_callback_called = true;
ASSERT_TRUE(h != nullptr);
EXPECT_EQ(h->prod_dev, default_device_);
EXPECT_EQ(h->prod_ctx, fake_device_context_);
EXPECT_EQ(h->prod_value, &a_);
br_->DoneWithHook(h);
},
&cm_);
EXPECT_TRUE(cons_callback_called);
note.WaitForNotification();
EXPECT_TRUE(prod_callback_called);
@ -116,17 +120,19 @@ TEST_F(BufRendezvousTest, CorrectUseConsumerFirst) {
bool prod_callback_called = false;
bool cons_callback_called = false;
Notification note;
br_->ConsumeBuf(*kDefaultKey, *kDefaultDeviceName, kDefaultIncarnation,
[this, &cons_status, &cons_callback_called](
const Status& s, BufRendezvous::Hook* h) {
cons_status = s;
cons_callback_called = true;
ASSERT_TRUE(h != nullptr);
EXPECT_EQ(h->prod_dev, default_device_);
EXPECT_EQ(h->prod_ctx, fake_device_context_);
EXPECT_EQ(h->prod_value, &a_);
br_->DoneWithHook(h);
});
br_->ConsumeBuf(
*kDefaultKey, *kDefaultDeviceName, kDefaultIncarnation,
[this, &cons_status, &cons_callback_called](const Status& s,
BufRendezvous::Hook* h) {
cons_status = s;
cons_callback_called = true;
ASSERT_TRUE(h != nullptr);
EXPECT_EQ(h->prod_dev, default_device_);
EXPECT_EQ(h->prod_ctx, fake_device_context_);
EXPECT_EQ(h->prod_value, &a_);
br_->DoneWithHook(h);
},
&cm_);
EXPECT_FALSE(cons_callback_called);
br_->ProvideBuf(
*kDefaultKey, default_device_, fake_device_context_, &a_, aa_,
@ -134,7 +140,8 @@ TEST_F(BufRendezvousTest, CorrectUseConsumerFirst) {
prod_status = s;
prod_callback_called = true;
note.Notify();
});
},
&cm_);
EXPECT_TRUE(cons_callback_called);
note.WaitForNotification();
EXPECT_TRUE(prod_callback_called);
@ -144,17 +151,19 @@ TEST_F(BufRendezvousTest, CorrectUseConsumerFirst) {
TEST_F(BufRendezvousTest, ErrorDuplicatePut) {
bool prod_callback_called = false;
br_->ProvideBuf(*kDefaultKey, default_device_, fake_device_context_, &a_, aa_,
[&prod_callback_called](const Status& s) {
prod_callback_called = true;
});
br_->ProvideBuf(
*kDefaultKey, default_device_, fake_device_context_, &a_, aa_,
[&prod_callback_called](const Status& s) { prod_callback_called = true; },
&cm_);
Status bad_status;
Notification note;
br_->ProvideBuf(*kDefaultKey, default_device_, fake_device_context_, &a_, aa_,
[&bad_status, &note](const Status& s) {
bad_status = s;
note.Notify();
});
br_->ProvideBuf(
*kDefaultKey, default_device_, fake_device_context_, &a_, aa_,
[&bad_status, &note](const Status& s) {
bad_status = s;
note.Notify();
},
&cm_);
note.WaitForNotification();
EXPECT_FALSE(bad_status.ok());
EXPECT_EQ(absl::StrCat("BufRendezvous::ProvideBuf already called for key ",
@ -166,11 +175,13 @@ TEST_F(BufRendezvousTest, ErrorDuplicatePut) {
TEST_F(BufRendezvousTest, ErrorDeleteNonEmpty) {
Status cons_status;
br_->ConsumeBuf(*kDefaultKey, *kDefaultDeviceName, kDefaultIncarnation,
[&cons_status](const Status& s, BufRendezvous::Hook* h) {
cons_status = s;
EXPECT_EQ(h, nullptr);
});
br_->ConsumeBuf(
*kDefaultKey, *kDefaultDeviceName, kDefaultIncarnation,
[&cons_status](const Status& s, BufRendezvous::Hook* h) {
cons_status = s;
EXPECT_EQ(h, nullptr);
},
&cm_);
EXPECT_TRUE(cons_status.ok());
br_.reset();
EXPECT_FALSE(cons_status.ok());
@ -188,12 +199,15 @@ TEST_F(BufRendezvousTest, AbortNonEmpty) {
[&cons_note, &cons_status](const Status& s, BufRendezvous::Hook* h) {
cons_status = s;
cons_note.Notify();
});
br_->ProvideBuf("key1", default_device_, fake_device_context_, &a_, aa_,
[&prod_note, &prod_status](const Status& s) {
prod_status = s;
prod_note.Notify();
});
},
&cm_);
br_->ProvideBuf(
"key1", default_device_, fake_device_context_, &a_, aa_,
[&prod_note, &prod_status](const Status& s) {
prod_status = s;
prod_note.Notify();
},
&cm_);
br_->StartAbort(errors::Internal("Falling sky detected"));
prod_note.WaitForNotification();
cons_note.WaitForNotification();
@ -218,12 +232,15 @@ TEST_F(BufRendezvousTest, UseAfterAbort) {
[&cons_note, &cons_status](const Status& s, BufRendezvous::Hook* h) {
cons_status = s;
cons_note.Notify();
});
br_->ProvideBuf("key1", default_device_, fake_device_context_, &a_, aa_,
[&prod_note, &prod_status](const Status& s) {
prod_status = s;
prod_note.Notify();
});
},
&cm_);
br_->ProvideBuf(
"key1", default_device_, fake_device_context_, &a_, aa_,
[&prod_note, &prod_status](const Status& s) {
prod_status = s;
prod_note.Notify();
},
&cm_);
prod_note.WaitForNotification();
cons_note.WaitForNotification();
EXPECT_FALSE(prod_status.ok());
@ -237,18 +254,161 @@ TEST_F(BufRendezvousTest, UseAfterAbort) {
TEST_F(BufRendezvousTest, DeviceIncarnationMismatch) {
Status cons_status;
Notification note;
br_->ProvideBuf(*kDefaultKey, default_device_, fake_device_context_, &a_, aa_,
[](const Status&) {});
br_->ProvideBuf(
*kDefaultKey, default_device_, fake_device_context_, &a_, aa_,
[](const Status&) {}, /*cancellation_manager=*/nullptr);
const uint64 incorrect_incarnation = 23456;
br_->ConsumeBuf(
*kDefaultKey, *kDefaultDeviceName, incorrect_incarnation,
[&note, &cons_status](const Status& s, BufRendezvous::Hook* h) {
cons_status = s;
note.Notify();
});
},
/*cancellation_manager=*/nullptr);
note.WaitForNotification();
EXPECT_TRUE(errors::IsFailedPrecondition(cons_status));
}
TEST_F(BufRendezvousTest, ProvideThenCancel) {
Status status;
Notification note;
br_->ProvideBuf(
*kDefaultKey, default_device_, fake_device_context_, &a_, aa_,
[&status, &note](const Status& s) {
status = s;
note.Notify();
},
&cm_);
cm_.StartCancel();
note.WaitForNotification();
EXPECT_TRUE(errors::IsCancelled(status));
EXPECT_NE(
status.error_message().find(absl::StrCat(
"Operation was cancelled for BufRendezvous key ", *kDefaultKey)),
string::npos);
}
TEST_F(BufRendezvousTest, CancelThenProvide) {
Status status;
Notification note;
cm_.StartCancel();
br_->ProvideBuf(
*kDefaultKey, default_device_, fake_device_context_, &a_, aa_,
[&status, &note](const Status& s) {
status = s;
note.Notify();
},
&cm_);
note.WaitForNotification();
EXPECT_TRUE(errors::IsCancelled(status));
EXPECT_NE(
status.error_message().find(absl::StrCat(
"Operation was cancelled for BufRendezvous key ", *kDefaultKey)),
string::npos);
}
TEST_F(BufRendezvousTest, ConsumeThenCancel) {
Status status;
Notification note;
br_->ConsumeBuf(
*kDefaultKey, *kDefaultDeviceName, kDefaultIncarnation,
[&status, &note](const Status& s, BufRendezvous::Hook* h) {
status = s;
note.Notify();
},
&cm_);
cm_.StartCancel();
note.WaitForNotification();
EXPECT_TRUE(errors::IsCancelled(status));
EXPECT_NE(
status.error_message().find(absl::StrCat(
"Operation was cancelled for BufRendezvous key ", *kDefaultKey)),
string::npos);
}
TEST_F(BufRendezvousTest, CancelThenConsume) {
Status status;
Notification note;
cm_.StartCancel();
br_->ConsumeBuf(
*kDefaultKey, *kDefaultDeviceName, kDefaultIncarnation,
[&status, &note](const Status& s, BufRendezvous::Hook* h) {
status = s;
note.Notify();
},
&cm_);
note.WaitForNotification();
EXPECT_TRUE(errors::IsCancelled(status));
EXPECT_NE(
status.error_message().find(absl::StrCat(
"Operation was cancelled for BufRendezvous key ", *kDefaultKey)),
string::npos);
}
TEST_F(BufRendezvousTest, ProvideConsumeThenCancel) {
Status prod_status;
Status cons_status;
bool prod_callback_called = false;
bool cons_callback_called = false;
Notification note;
br_->ProvideBuf(
*kDefaultKey, default_device_, fake_device_context_, &a_, aa_,
[&note, &prod_status, &prod_callback_called](const Status& s) {
prod_status = s;
prod_callback_called = true;
note.Notify();
},
&cm_);
EXPECT_FALSE(prod_callback_called);
br_->ConsumeBuf(
*kDefaultKey, *kDefaultDeviceName, kDefaultIncarnation,
[this, &cons_status, &cons_callback_called](const Status& s,
BufRendezvous::Hook* h) {
cons_status = s;
cons_callback_called = true;
ASSERT_TRUE(h != nullptr);
EXPECT_EQ(h->prod_dev, default_device_);
EXPECT_EQ(h->prod_ctx, fake_device_context_);
EXPECT_EQ(h->prod_value, &a_);
br_->DoneWithHook(h);
},
&cm_);
note.WaitForNotification();
cm_.StartCancel();
EXPECT_TRUE(cons_callback_called);
EXPECT_TRUE(prod_callback_called);
TF_EXPECT_OK(cons_status);
TF_EXPECT_OK(prod_status);
}
TEST_F(BufRendezvousTest, CancelThenProvideConsume) {
Status prod_status;
Status cons_status;
bool prod_callback_called = false;
bool cons_callback_called = false;
cm_.StartCancel();
br_->ProvideBuf(
*kDefaultKey, default_device_, fake_device_context_, &a_, aa_,
[&prod_status, &prod_callback_called](const Status& s) {
prod_status = s;
EXPECT_TRUE(errors::IsCancelled(prod_status));
prod_callback_called = true;
},
&cm_);
EXPECT_TRUE(prod_callback_called);
EXPECT_TRUE(errors::IsCancelled(prod_status));
br_->ConsumeBuf(
*kDefaultKey, *kDefaultDeviceName, kDefaultIncarnation,
[&cons_status, &cons_callback_called](const Status& s,
BufRendezvous::Hook* h) {
cons_status = s;
EXPECT_TRUE(errors::IsCancelled(cons_status));
cons_callback_called = true;
},
&cm_);
EXPECT_TRUE(cons_callback_called);
EXPECT_TRUE(errors::IsCancelled(cons_status));
}
} // namespace
} // namespace tensorflow

View File

@ -28,7 +28,7 @@ void CollectiveRemoteAccessLocal::RecvFromPeer(
const string& key, Device* to_device, DeviceContext* to_device_ctx,
const AllocatorAttributes& to_alloc_attr, Tensor* to_tensor,
const DeviceLocality& client_locality, int dev_to_dev_stream_index,
const StatusCallback& done) {
CancellationManager* cancellation_manager, const StatusCallback& done) {
VLOG(1) << "RecvFromPeer " << this << " from " << peer_device << " key "
<< key;
if (!peer_is_local) {
@ -91,18 +91,19 @@ void CollectiveRemoteAccessLocal::RecvFromPeer(
};
buf_rendezvous_.ConsumeBuf(key, from_device->name(),
from_device->attributes().incarnation(),
consumer_callback);
consumer_callback, cancellation_manager);
}
void CollectiveRemoteAccessLocal::PostToPeer(
const string& peer_device, const string& peer_task, const string& key,
Device* from_device, DeviceContext* from_device_ctx,
const AllocatorAttributes& from_alloc_attr, const Tensor* from_tensor,
const DeviceLocality& client_locality, const StatusCallback& done) {
const DeviceLocality& client_locality,
CancellationManager* cancellation_manager, const StatusCallback& done) {
VLOG(1) << "PostToPeer " << this << " key " << key
<< " step_id_=" << step_id_;
buf_rendezvous_.ProvideBuf(key, from_device, from_device_ctx, from_tensor,
from_alloc_attr, done);
from_alloc_attr, done, cancellation_manager);
}
void CollectiveRemoteAccessLocal::CheckPeerHealth(const string& peer_task,

View File

@ -43,6 +43,7 @@ class CollectiveRemoteAccessLocal : public CollectiveRemoteAccess {
const AllocatorAttributes& to_alloc_attr, Tensor* to_tensor,
const DeviceLocality& client_locality,
int dev_to_dev_stream_index,
CancellationManager* cancellation_manager,
const StatusCallback& done) override;
void PostToPeer(const string& peer_device, const string& peer_task,
@ -51,6 +52,7 @@ class CollectiveRemoteAccessLocal : public CollectiveRemoteAccess {
const AllocatorAttributes& from_alloc_attr,
const Tensor* from_tensor,
const DeviceLocality& client_locality,
CancellationManager* cancellation_manager,
const StatusCallback& done) override;
void CheckPeerHealth(const string& peer_task,

View File

@ -52,6 +52,7 @@ class CollectiveRemoteAccessLocalTest : public ::testing::Test {
cp, device_mgr_.get(), drl_.get(), kTaskName);
rma_ = absl::make_unique<CollectiveRemoteAccessLocal>(device_mgr_.get(),
drl_.get(), kStepId);
cm_ = absl::make_unique<CancellationManager>();
}
~CollectiveRemoteAccessLocalTest() override = default;
@ -61,6 +62,7 @@ class CollectiveRemoteAccessLocalTest : public ::testing::Test {
std::unique_ptr<DeviceResolverLocal> drl_;
std::unique_ptr<CollectiveParamResolverLocal> prl_;
std::unique_ptr<CollectiveRemoteAccessLocal> rma_;
std::unique_ptr<CancellationManager> cm_;
};
TEST_F(CollectiveRemoteAccessLocalTest, PostRecvCPU0) {
@ -74,7 +76,7 @@ TEST_F(CollectiveRemoteAccessLocalTest, PostRecvCPU0) {
rma_->RecvFromPeer(kTaskName + "/device:CPU:0", kTaskName, true /*is_local*/,
"key_0", cpu0 /*to_device*/, nullptr /*to_device_ctx*/,
attr /*to_alloc_attr*/, &sink_tensor, dev_locality,
0 /*stream_index*/,
0 /*stream_index*/, cm_.get(),
[&recv_note, &recv_status](const Status& s) {
recv_status = s;
recv_note.Notify();
@ -90,7 +92,7 @@ TEST_F(CollectiveRemoteAccessLocalTest, PostRecvCPU0) {
rma_->PostToPeer(kTaskName + "/device:CPU:0", kTaskName, "key_0",
cpu0 /*from_device*/, nullptr /*from_device_ctx*/,
attr /*to_alloc_attr*/, &source_tensor, dev_locality,
[&send_note, &send_status](const Status& s) {
cm_.get(), [&send_note, &send_status](const Status& s) {
send_status = s;
send_note.Notify();
});
@ -117,7 +119,7 @@ TEST_F(CollectiveRemoteAccessLocalTest, PostRecvCPU1_2) {
rma_->RecvFromPeer(kTaskName + "/device:CPU:1", kTaskName, true /*is_local*/,
"key_0", cpu2 /*to_device*/, nullptr /*to_device_ctx*/,
attr /*to_alloc_attr*/, &sink_tensor, dev_locality,
0 /*stream_index*/,
0 /*stream_index*/, cm_.get(),
[&recv_note, &recv_status](const Status& s) {
recv_status = s;
recv_note.Notify();
@ -135,7 +137,7 @@ TEST_F(CollectiveRemoteAccessLocalTest, PostRecvCPU1_2) {
rma_->PostToPeer(kTaskName + "/device:CPU:2", kTaskName, "key_0",
cpu1 /*from_device*/, nullptr /*from_device_ctx*/,
attr /*to_alloc_attr*/, &source_tensor, dev_locality,
[&send_note, &send_status](const Status& s) {
cm_.get(), [&send_note, &send_status](const Status& s) {
send_status = s;
send_note.Notify();
});
@ -162,5 +164,91 @@ TEST_F(CollectiveRemoteAccessLocalTest, CheckHealth) {
EXPECT_TRUE(errors::IsInternal(status));
}
TEST_F(CollectiveRemoteAccessLocalTest, RecvThenCancel) {
Device* cpu0 = nullptr;
AllocatorAttributes attr;
DeviceLocality dev_locality;
TF_ASSERT_OK(device_mgr_->LookupDevice(kTaskName + "/device:CPU:0", &cpu0));
Tensor sink_tensor(DT_FLOAT, TensorShape({8}));
Notification recv_note;
Status recv_status;
rma_->RecvFromPeer(kTaskName + "/device:CPU:0", kTaskName, true /*is_local*/,
"key_0", cpu0 /*to_device*/, nullptr /*to_device_ctx*/,
attr /*to_alloc_attr*/, &sink_tensor, dev_locality,
0 /*stream_index*/, cm_.get(),
[&recv_note, &recv_status](const Status& s) {
recv_status = s;
recv_note.Notify();
});
cm_->StartCancel();
recv_note.WaitForNotification();
EXPECT_TRUE(cm_->IsCancelled());
EXPECT_TRUE(errors::IsCancelled(recv_status));
}
TEST_F(CollectiveRemoteAccessLocalTest, CancelThenRecv) {
Device* cpu0 = nullptr;
AllocatorAttributes attr;
DeviceLocality dev_locality;
TF_ASSERT_OK(device_mgr_->LookupDevice(kTaskName + "/device:CPU:0", &cpu0));
Tensor sink_tensor(DT_FLOAT, TensorShape({8}));
Notification recv_note;
Status recv_status;
cm_->StartCancel();
rma_->RecvFromPeer(kTaskName + "/device:CPU:0", kTaskName, true /*is_local*/,
"key_0", cpu0 /*to_device*/, nullptr /*to_device_ctx*/,
attr /*to_alloc_attr*/, &sink_tensor, dev_locality,
0 /*stream_index*/, cm_.get(),
[&recv_note, &recv_status](const Status& s) {
recv_status = s;
recv_note.Notify();
});
recv_note.WaitForNotification();
EXPECT_TRUE(cm_->IsCancelled());
EXPECT_TRUE(errors::IsCancelled(recv_status));
}
TEST_F(CollectiveRemoteAccessLocalTest, PostThenCancel) {
Device* cpu0 = nullptr;
AllocatorAttributes attr;
DeviceLocality dev_locality;
TF_ASSERT_OK(device_mgr_->LookupDevice(kTaskName + "/device:CPU:0", &cpu0));
Tensor source_tensor(DT_FLOAT, TensorShape({8}));
Notification send_note;
Status send_status;
rma_->PostToPeer(kTaskName + "/device:CPU:0", kTaskName, "key_0",
cpu0 /*from_device*/, nullptr /*from_device_ctx*/,
attr /*to_alloc_attr*/, &source_tensor, dev_locality,
cm_.get(), [&send_note, &send_status](const Status& s) {
send_status = s;
send_note.Notify();
});
cm_->StartCancel();
send_note.WaitForNotification();
EXPECT_TRUE(cm_->IsCancelled());
EXPECT_TRUE(errors::IsCancelled(send_status));
}
TEST_F(CollectiveRemoteAccessLocalTest, CancelThenPost) {
Device* cpu0 = nullptr;
AllocatorAttributes attr;
DeviceLocality dev_locality;
TF_ASSERT_OK(device_mgr_->LookupDevice(kTaskName + "/device:CPU:0", &cpu0));
Tensor source_tensor(DT_FLOAT, TensorShape({8}));
Notification send_note;
Status send_status;
cm_->StartCancel();
rma_->PostToPeer(kTaskName + "/device:CPU:0", kTaskName, "key_0",
cpu0 /*from_device*/, nullptr /*from_device_ctx*/,
attr /*to_alloc_attr*/, &source_tensor, dev_locality,
cm_.get(), [&send_note, &send_status](const Status& s) {
send_status = s;
send_note.Notify();
});
send_note.WaitForNotification();
EXPECT_TRUE(cm_->IsCancelled());
EXPECT_TRUE(errors::IsCancelled(send_status));
}
} // namespace
} // namespace tensorflow

View File

@ -423,7 +423,8 @@ void HierarchicalTreeBroadcaster::DispatchSend(int subdiv, int dst_rank,
col_params_->group.task_names[dst_idx], send_buf_key, col_ctx_->device,
col_ctx_->op_ctx->op_device_context(),
col_ctx_->op_ctx->output_alloc_attr(0), src_tensor,
col_ctx_->device_locality, done);
col_ctx_->device_locality, col_ctx_->op_ctx->cancellation_manager(),
done);
}
void HierarchicalTreeBroadcaster::DispatchRecv(int subdiv, int src_rank,
@ -443,7 +444,8 @@ void HierarchicalTreeBroadcaster::DispatchRecv(int subdiv, int src_rank,
col_params_->task.is_local[src_idx], recv_buf_key, col_ctx_->device,
col_ctx_->op_ctx->op_device_context(),
col_ctx_->op_ctx->output_alloc_attr(0), dst_tensor,
col_ctx_->device_locality, 0 /*stream_index*/, done);
col_ctx_->device_locality, 0 /*stream_index*/,
col_ctx_->op_ctx->cancellation_manager(), done);
}
namespace {

View File

@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/process_util.h"
#include "tensorflow/core/common_runtime/test_collective_executor_mgr.h"
#include "tensorflow/core/common_runtime/threadpool_device.h"
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/framework/collective.h"
#include "tensorflow/core/framework/fake_input.h"
#include "tensorflow/core/framework/node_def.pb.h"
@ -165,11 +166,13 @@ class FailTestRMA : public CollectiveRemoteAccessLocal {
DeviceContext* to_device_ctx,
const AllocatorAttributes& to_alloc_attr, Tensor* to_tensor,
const DeviceLocality& client_locality, int stream_index,
CancellationManager* cancellation_manager,
const StatusCallback& done) override {
if (MaybeFail(done)) return;
CollectiveRemoteAccessLocal::RecvFromPeer(
peer_device, peer_task, peer_is_local, key, to_device, to_device_ctx,
to_alloc_attr, to_tensor, client_locality, stream_index, done);
to_alloc_attr, to_tensor, client_locality, stream_index,
cancellation_manager, done);
}
void PostToPeer(const string& peer_device, const string& peer_task,
@ -178,11 +181,13 @@ class FailTestRMA : public CollectiveRemoteAccessLocal {
const AllocatorAttributes& from_alloc_attr,
const Tensor* from_tensor,
const DeviceLocality& client_locality,
CancellationManager* cancellation_manager,
const StatusCallback& done) override {
if (MaybeFail(done)) return;
CollectiveRemoteAccessLocal::PostToPeer(
peer_device, peer_task, key, from_device, from_device_ctx,
from_alloc_attr, from_tensor, client_locality, done);
from_alloc_attr, from_tensor, client_locality, cancellation_manager,
done);
}
mutex mu_;
@ -618,6 +623,7 @@ class HierarchicalTreeBroadcasterTest : public ::testing::Test {
OpKernelContext::Params op_params;
op_params.step_id = parent_->step_id_;
op_params.device = device_;
op_params.cancellation_manager = &parent_->cancellation_manager_;
gtl::InlinedVector<TensorValue, 4> inputs;
inputs.push_back(TensorValue(&tensor_));
op_params.inputs = &inputs;
@ -710,6 +716,7 @@ class HierarchicalTreeBroadcasterTest : public ::testing::Test {
int bcast_recv_counter_ TF_GUARDED_BY(mu_) = 0;
int bcast_send_counter_ TF_GUARDED_BY(mu_) = 0;
int failure_count_ TF_GUARDED_BY(mu_) = 0;
CancellationManager cancellation_manager_;
};
TEST_F(HierarchicalTreeBroadcasterTest, InitializeParams1Task8GPU) {

View File

@ -90,7 +90,7 @@ void Permuter::DispatchSend(int src_rank, int target_rank, const Tensor* tensor,
col_params_->group.task_names[target_rank], send_buf_key,
col_ctx_->device, col_ctx_->op_ctx->op_device_context(),
col_ctx_->op_ctx->output_alloc_attr(0), tensor, col_ctx_->device_locality,
done);
col_ctx_->op_ctx->cancellation_manager(), done);
}
void Permuter::DispatchRecv(int src_rank, int target_rank, Tensor* tensor,
@ -107,7 +107,7 @@ void Permuter::DispatchRecv(int src_rank, int target_rank, Tensor* tensor,
col_params_->task.is_local[src_rank], recv_buf_key, col_ctx_->device,
col_ctx_->op_ctx->op_device_context(),
col_ctx_->op_ctx->output_alloc_attr(0), tensor, col_ctx_->device_locality,
0, done);
0, col_ctx_->op_ctx->cancellation_manager(), done);
}
namespace {
REGISTER_COLLECTIVE(Permute, Permuter);

View File

@ -77,11 +77,13 @@ class FailTestRMA : public CollectiveRemoteAccessLocal {
DeviceContext* to_device_ctx,
const AllocatorAttributes& to_alloc_attr, Tensor* to_tensor,
const DeviceLocality& client_locality, int stream_index,
CancellationManager* cancellation_manager,
const StatusCallback& done) override {
if (MaybeFail(done)) return;
CollectiveRemoteAccessLocal::RecvFromPeer(
peer_device, peer_task, peer_is_local, key, to_device, to_device_ctx,
to_alloc_attr, to_tensor, client_locality, stream_index, done);
to_alloc_attr, to_tensor, client_locality, stream_index,
cancellation_manager, done);
}
void PostToPeer(const string& peer_device, const string& peer_task,
@ -90,11 +92,13 @@ class FailTestRMA : public CollectiveRemoteAccessLocal {
const AllocatorAttributes& from_alloc_attr,
const Tensor* from_tensor,
const DeviceLocality& client_locality,
CancellationManager* cancellation_manager,
const StatusCallback& done) override {
if (MaybeFail(done)) return;
CollectiveRemoteAccessLocal::PostToPeer(
peer_device, peer_task, key, from_device, from_device_ctx,
from_alloc_attr, from_tensor, client_locality, done);
from_alloc_attr, from_tensor, client_locality, cancellation_manager,
done);
}
mutex mu_;
@ -361,6 +365,7 @@ class PermuterTest : public ::testing::Test {
OpKernelContext::Params op_params;
op_params.step_id = parent_->step_id_;
op_params.device = device_;
op_params.cancellation_manager = &parent_->cancellation_manager_;
gtl::InlinedVector<TensorValue, 4> inputs;
inputs.push_back(TensorValue(&tensor_input_));
op_params.inputs = &inputs;
@ -427,6 +432,7 @@ class PermuterTest : public ::testing::Test {
mutex mu_;
int permute_counter_ TF_GUARDED_BY(mu_) = 0;
std::vector<int> permutation_;
CancellationManager cancellation_manager_;
};
// TODO(b/113171733): change to use TEST_P.

View File

@ -389,7 +389,8 @@ void RingAlg::DispatchSend(RingField* rf, const StatusCallback& done) {
col_params_->group.task_names[send_to_dev_idx], send_buf_key,
col_ctx_->device, col_ctx_->op_ctx->op_device_context(),
col_ctx_->op_ctx->output_alloc_attr(0), &rf->chunk,
col_ctx_->device_locality, done);
col_ctx_->device_locality, col_ctx_->op_ctx->cancellation_manager(),
done);
}
void RingAlg::DispatchRecv(RingField* rf, const StatusCallback& done) {
@ -409,7 +410,8 @@ void RingAlg::DispatchRecv(RingField* rf, const StatusCallback& done) {
col_params_->task.is_local[rf->recv_dev_idx], recv_buf_key,
col_ctx_->device, col_ctx_->op_ctx->op_device_context(),
col_ctx_->op_ctx->output_alloc_attr(0), dst_tensor,
col_ctx_->device_locality, rf->subdiv_idx, done);
col_ctx_->device_locality, rf->subdiv_idx,
col_ctx_->op_ctx->cancellation_manager(), done);
}
string RingAlg::FieldState() {

View File

@ -70,12 +70,13 @@ class FailTestRMA : public CollectiveRemoteAccessLocal {
const AllocatorAttributes& to_alloc_attr, Tensor* to_tensor,
const DeviceLocality& client_locality,
int dev_to_dev_stream_index,
CancellationManager* cancellation_manager,
const StatusCallback& done) override {
if (MaybeFail(done)) return;
CollectiveRemoteAccessLocal::RecvFromPeer(
peer_device, peer_task, peer_is_local, key, to_device, to_device_ctx,
to_alloc_attr, to_tensor, client_locality, dev_to_dev_stream_index,
done);
cancellation_manager, done);
}
void PostToPeer(const string& peer_device, const string& peer_task,
@ -84,11 +85,13 @@ class FailTestRMA : public CollectiveRemoteAccessLocal {
const AllocatorAttributes& from_alloc_attr,
const Tensor* from_tensor,
const DeviceLocality& client_locality,
CancellationManager* cancellation_manager,
const StatusCallback& done) override {
if (MaybeFail(done)) return;
CollectiveRemoteAccessLocal::PostToPeer(
peer_device, peer_task, key, from_device, from_device_ctx,
from_alloc_attr, from_tensor, client_locality, done);
from_alloc_attr, from_tensor, client_locality, cancellation_manager,
done);
}
mutex mu_;
@ -442,6 +445,7 @@ class RingGathererTest : public ::testing::Test {
OpKernelContext::Params op_params;
op_params.step_id = kStepId;
op_params.device = device_;
op_params.cancellation_manager = &parent_->cancellation_manager_;
gtl::InlinedVector<TensorValue, 4> inputs;
inputs.push_back(TensorValue(&input_tensor_));
op_params.inputs = &inputs;
@ -523,6 +527,7 @@ class RingGathererTest : public ::testing::Test {
std::unique_ptr<string> gpu_ring_order_;
mutex mu_;
int32 gather_counter_ TF_GUARDED_BY(mu_) = 0;
CancellationManager cancellation_manager_;
};
CollectiveParams SetUpCollectiveParams(const int num_devs_per_task,

View File

@ -25,6 +25,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/process_util.h"
#include "tensorflow/core/common_runtime/test_collective_executor_mgr.h"
#include "tensorflow/core/common_runtime/threadpool_device.h"
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/framework/collective.h"
#include "tensorflow/core/framework/fake_input.h"
#include "tensorflow/core/framework/node_def.pb.h"
@ -70,12 +71,13 @@ class FailTestRMA : public CollectiveRemoteAccessLocal {
const AllocatorAttributes& to_alloc_attr, Tensor* to_tensor,
const DeviceLocality& client_locality,
int dev_to_dev_stream_index,
CancellationManager* cancellation_manager,
const StatusCallback& done) override {
if (MaybeFail(done)) return;
CollectiveRemoteAccessLocal::RecvFromPeer(
peer_device, peer_task, peer_is_local, key, to_device, to_device_ctx,
to_alloc_attr, to_tensor, client_locality, dev_to_dev_stream_index,
done);
cancellation_manager, done);
}
void PostToPeer(const string& peer_device, const string& peer_task,
@ -84,11 +86,13 @@ class FailTestRMA : public CollectiveRemoteAccessLocal {
const AllocatorAttributes& from_alloc_attr,
const Tensor* from_tensor,
const DeviceLocality& client_locality,
CancellationManager* cancellation_manager,
const StatusCallback& done) override {
if (MaybeFail(done)) return;
CollectiveRemoteAccessLocal::PostToPeer(
peer_device, peer_task, key, from_device, from_device_ctx,
from_alloc_attr, from_tensor, client_locality, done);
from_alloc_attr, from_tensor, client_locality, cancellation_manager,
done);
}
mutex mu_;
@ -471,6 +475,7 @@ class RingReducerTest : public ::testing::Test {
OpKernelContext::Params op_params;
op_params.step_id = kStepId;
op_params.device = device_;
op_params.cancellation_manager = &parent_->cancellation_manager_;
gtl::InlinedVector<TensorValue, 4> inputs;
inputs.push_back(TensorValue(&tensor_));
op_params.inputs = &inputs;
@ -550,6 +555,7 @@ class RingReducerTest : public ::testing::Test {
std::unique_ptr<string> gpu_ring_order_;
mutex mu_;
int32 reduce_counter_ TF_GUARDED_BY(mu_) = 0;
CancellationManager cancellation_manager_;
};
CollectiveParams SetUpCollectiveParams(const int num_devs_per_task,

View File

@ -78,12 +78,12 @@ void CollectiveRemoteAccessDistributed::RecvFromPeer(
const string& key, Device* to_device, DeviceContext* to_device_ctx,
const AllocatorAttributes& to_alloc_attr, Tensor* to_tensor,
const DeviceLocality& client_locality, int dev_to_dev_stream_index,
const StatusCallback& done) {
CancellationManager* cancellation_manager, const StatusCallback& done) {
if (peer_is_local) {
CollectiveRemoteAccessLocal::RecvFromPeer(
peer_device, peer_task, peer_is_local, key, to_device, to_device_ctx,
to_alloc_attr, to_tensor, client_locality, dev_to_dev_stream_index,
done);
cancellation_manager, done);
return;
}
@ -166,10 +166,15 @@ void CollectiveRemoteAccessDistributed::RecvFromPeer(
recv_buf_callback(s);
return;
}
state->call.reset(
new RecvBufCall(step_id_, peer_device, peer_task, key, to_device,
to_device_ctx, to_alloc_attr, to_tensor, client_locality,
state->server_attributes, &cancel_mgr_, worker_cache_));
// If a per-call `cancellation_manager` is passed to this function, prefer
// using that over `abortion_cancellation_manager_`. This is because abortion
// should also be accompanied by opkernel cancellation.
state->call.reset(new RecvBufCall(
step_id_, peer_device, peer_task, key, to_device, to_device_ctx,
to_alloc_attr, to_tensor, client_locality, state->server_attributes,
cancellation_manager == nullptr ? &abortion_cancellation_manager_
: cancellation_manager,
worker_cache_));
state->call->Start(recv_buf_callback);
}
@ -231,7 +236,7 @@ void CollectiveRemoteAccessDistributed::CheckPeerHealth(
void CollectiveRemoteAccessDistributed::StartAbort(const Status& s) {
CollectiveRemoteAccessLocal::StartAbort(s);
cancel_mgr_.StartCancel();
abortion_cancellation_manager_.StartCancel();
}
} // namespace tensorflow

View File

@ -42,6 +42,7 @@ class CollectiveRemoteAccessDistributed : public CollectiveRemoteAccessLocal {
const AllocatorAttributes& to_alloc_attr, Tensor* to_tensor,
const DeviceLocality& client_locality,
int dev_to_dev_stream_index,
CancellationManager* cancellation_manager,
const StatusCallback& done) override;
void CheckPeerHealth(const string& peer_task,
@ -54,7 +55,7 @@ class CollectiveRemoteAccessDistributed : public CollectiveRemoteAccessLocal {
// Ownership of `work_queue_` is shared between `this` and
// `CollectiveExecutorMgr`.
std::shared_ptr<UnboundedWorkQueue> work_queue_;
CancellationManager cancel_mgr_;
CancellationManager abortion_cancellation_manager_;
string task_name_;
};

View File

@ -126,7 +126,8 @@ class FakeWorker : public TestWorkerInterface {
}
done(s);
if (h) BufRendezvous::DoneWithHook(h);
});
},
nullptr /*cancellation_manager*/);
}
private:
@ -311,7 +312,8 @@ TEST_F(CollRMADistTest, ProdFirstOK) {
[&producer_note, &producer_status](const Status& s) {
producer_status.Update(s);
producer_note.Notify();
});
},
nullptr /*cancellation_manager*/);
Device* dst_device = nullptr;
string dev_name = "CPU:0";
TF_EXPECT_OK(device_mgrs_[0]->LookupDevice(dev_name, &dst_device));
@ -322,6 +324,7 @@ TEST_F(CollRMADistTest, ProdFirstOK) {
false, // peer_is_local
kBufKey, dst_device, to_device_ctx, alloc_attr_, &to_tensor_,
device_locality_, 0 /*dev_to_dev_stream_index*/,
nullptr /*cancellation_manager*/,
[&consumer_status, &consumer_note](const Status& s) {
consumer_status = s;
consumer_note.Notify();
@ -351,6 +354,7 @@ TEST_F(CollRMADistTest, ConsFirstOK) {
false, // peer_is_local
kBufKey, dst_device, to_device_ctx, alloc_attr_, &to_tensor_,
device_locality_, 0 /*dev_to_dev_stream_index*/,
nullptr /*cancellation_manager*/,
[&consumer_status, &consumer_note](const Status& s) {
consumer_status = s;
consumer_note.Notify();
@ -361,7 +365,8 @@ TEST_F(CollRMADistTest, ConsFirstOK) {
[&producer_note, &producer_status](const Status& s) {
producer_status.Update(s);
producer_note.Notify();
});
},
nullptr /*cancellation_manager*/);
consumer_note.WaitForNotification();
TF_EXPECT_OK(consumer_status);
producer_note.WaitForNotification();
@ -384,6 +389,7 @@ TEST_F(CollRMADistTest, ConsFirstAbort) {
false, // peer_is_local
kBufKey, dst_device, to_device_ctx, alloc_attr_, &to_tensor_,
device_locality_, 0 /*dev_to_dev_stream_index*/,
nullptr /*cancellation_manager*/,
[&consumer_status, &consumer_note](const Status& s) {
consumer_status = s;
consumer_note.Notify();
@ -411,6 +417,7 @@ TEST_F(CollRMADistTest, WorkerRestart) {
false, // peer_is_local
buf_key, dst_device, to_device_ctx, alloc_attr_, &to_tensor_,
device_locality_, 0 /*dev_to_dev_stream_index*/,
nullptr /*cancellation_manager*/,
[&consumer_status, &consumer_note](const Status& s) {
consumer_status = s;
consumer_note.Notify();
@ -421,7 +428,8 @@ TEST_F(CollRMADistTest, WorkerRestart) {
[&producer_note, &producer_status](const Status& s) {
producer_status.Update(s);
producer_note.Notify();
});
},
nullptr /*cancellation_manager*/);
consumer_note.WaitForNotification();
TF_EXPECT_OK(consumer_status);
producer_note.WaitForNotification();
@ -437,6 +445,7 @@ TEST_F(CollRMADistTest, WorkerRestart) {
false, // peer_is_local
buf_key, dst_device, to_device_ctx, alloc_attr_, &to_tensor_,
device_locality_, 0 /*dev_to_dev_stream_index*/,
nullptr /*cancellation_manager*/,
[&consumer_status, &post_restart_note](const Status& s) {
consumer_status = s;
post_restart_note.Notify();

View File

@ -696,7 +696,8 @@ void GrpcWorker::RecvBufAsync(CallOptions* opts, const RecvBufRequest* request,
};
rma->buf_rendezvous()->ConsumeBuf(
request->buf_rendezvous_key(), request->src_device(),
request->src_incarnation(), consumer_callback);
request->src_incarnation(), consumer_callback,
/*cancellation_manager=*/nullptr);
}
void GrpcWorker::LoggingAsync(const LoggingRequest* request,

View File

@ -268,6 +268,7 @@ class CollectiveRemoteAccess {
Tensor* to_tensor,
const DeviceLocality& client_locality,
int dev_to_dev_stream_index,
CancellationManager* cancellation_manager,
const StatusCallback& done) = 0;
virtual void PostToPeer(const string& peer_device, const string& peer_task,
@ -276,6 +277,7 @@ class CollectiveRemoteAccess {
const AllocatorAttributes& from_alloc_attr,
const Tensor* from_tensor,
const DeviceLocality& client_locality,
CancellationManager* cancellation_manager,
const StatusCallback& done) = 0;
// Checks the health of a collective peer. It probes the peer to see if it is