From 55119aadc69d394047e5f75d514fb6488cd4adb4 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 12 Dec 2019 23:11:58 -0800 Subject: [PATCH] Add CancellationManager::IsCancelling() method. PiperOrigin-RevId: 285343578 Change-Id: Idc1d293d5567cb9e12f43fcde00d219da111b8d8 --- tensorflow/core/framework/cancellation.cc | 5 ++++ tensorflow/core/framework/cancellation.h | 3 +++ .../core/framework/cancellation_test.cc | 25 +++++++++++++++++++ 3 files changed, 33 insertions(+) diff --git a/tensorflow/core/framework/cancellation.cc b/tensorflow/core/framework/cancellation.cc index a91442fcbad..99ac9a70ac1 100644 --- a/tensorflow/core/framework/cancellation.cc +++ b/tensorflow/core/framework/cancellation.cc @@ -201,4 +201,9 @@ CancellationManager::~CancellationManager() { } } +bool CancellationManager::IsCancelling() { + mutex_lock lock(mu_); + return is_cancelling_; +} + } // end namespace tensorflow diff --git a/tensorflow/core/framework/cancellation.h b/tensorflow/core/framework/cancellation.h index 3e1727ae54a..7e60eb54065 100644 --- a/tensorflow/core/framework/cancellation.h +++ b/tensorflow/core/framework/cancellation.h @@ -143,6 +143,9 @@ class CancellationManager { // called. bool TryDeregisterCallback(CancellationToken token); + // Returns true iff cancellation is in progress. + bool IsCancelling(); + private: struct State { Notification cancelled_notification; diff --git a/tensorflow/core/framework/cancellation_test.cc b/tensorflow/core/framework/cancellation_test.cc index e4994350ddd..743b3905b83 100644 --- a/tensorflow/core/framework/cancellation_test.cc +++ b/tensorflow/core/framework/cancellation_test.cc @@ -108,6 +108,7 @@ TEST(Cancellation, IsCancelled) { w.Schedule([n, cm]() { while (!cm->IsCancelled()) { } + ASSERT_FALSE(cm->IsCancelling()); n->Notify(); }); } @@ -119,6 +120,30 @@ TEST(Cancellation, IsCancelled) { delete cm; } +TEST(Cancellation, IsCancelling) { + CancellationManager cm; + Notification started_cancelling; + Notification can_finish_cancel; + Notification cancel_done; + thread::ThreadPool w(Env::Default(), "test", 1); + auto token = cm.get_cancellation_token(); + ASSERT_TRUE( + cm.RegisterCallback(token, [&started_cancelling, &can_finish_cancel]() { + started_cancelling.Notify(); + can_finish_cancel.WaitForNotification(); + })); + w.Schedule([&cm, &cancel_done]() { + cm.StartCancel(); + cancel_done.Notify(); + }); + started_cancelling.WaitForNotification(); + ASSERT_TRUE(cm.IsCancelling()); + can_finish_cancel.Notify(); + cancel_done.WaitForNotification(); + ASSERT_FALSE(cm.IsCancelling()); + ASSERT_TRUE(cm.IsCancelled()); +} + TEST(Cancellation, TryDeregisterWithoutCancel) { bool is_cancelled = false; CancellationManager* manager = new CancellationManager();