Add CancellationManager::IsCancelling() method.

PiperOrigin-RevId: 285343578
Change-Id: Idc1d293d5567cb9e12f43fcde00d219da111b8d8
This commit is contained in:
A. Unique TensorFlower 2019-12-12 23:11:58 -08:00 committed by TensorFlower Gardener
parent d1574c2093
commit 55119aadc6
3 changed files with 33 additions and 0 deletions

View File

@ -201,4 +201,9 @@ CancellationManager::~CancellationManager() {
}
}
bool CancellationManager::IsCancelling() {
mutex_lock lock(mu_);
return is_cancelling_;
}
} // end namespace tensorflow

View File

@ -143,6 +143,9 @@ class CancellationManager {
// called.
bool TryDeregisterCallback(CancellationToken token);
// Returns true iff cancellation is in progress.
bool IsCancelling();
private:
struct State {
Notification cancelled_notification;

View File

@ -108,6 +108,7 @@ TEST(Cancellation, IsCancelled) {
w.Schedule([n, cm]() {
while (!cm->IsCancelled()) {
}
ASSERT_FALSE(cm->IsCancelling());
n->Notify();
});
}
@ -119,6 +120,30 @@ TEST(Cancellation, IsCancelled) {
delete cm;
}
TEST(Cancellation, IsCancelling) {
CancellationManager cm;
Notification started_cancelling;
Notification can_finish_cancel;
Notification cancel_done;
thread::ThreadPool w(Env::Default(), "test", 1);
auto token = cm.get_cancellation_token();
ASSERT_TRUE(
cm.RegisterCallback(token, [&started_cancelling, &can_finish_cancel]() {
started_cancelling.Notify();
can_finish_cancel.WaitForNotification();
}));
w.Schedule([&cm, &cancel_done]() {
cm.StartCancel();
cancel_done.Notify();
});
started_cancelling.WaitForNotification();
ASSERT_TRUE(cm.IsCancelling());
can_finish_cancel.Notify();
cancel_done.WaitForNotification();
ASSERT_FALSE(cm.IsCancelling());
ASSERT_TRUE(cm.IsCancelled());
}
TEST(Cancellation, TryDeregisterWithoutCancel) {
bool is_cancelled = false;
CancellationManager* manager = new CancellationManager();