From 987046e078246a1b80a88e6dc1e9158121f452f0 Mon Sep 17 00:00:00 2001
From: Derek Murray <mrry@google.com>
Date: Wed, 3 Jul 2019 08:06:30 -0700
Subject: [PATCH] Add a SWIG wrapper for the `tensorflow::CancellationManager`
 class.

This change is a step towards supporting user-driven cancellation for eager function calls. In a future change, I plan to add an experimental method for calling a `tf.function` and passing a `CancellationManager` argument, so that the caller can cancel execution asynchronously.

PiperOrigin-RevId: 256369003
---
 tensorflow/c/eager/c_api_experimental.cc      | 19 +++++++++
 tensorflow/c/eager/c_api_experimental.h       | 13 ++++++
 tensorflow/c/eager/c_api_experimental_test.cc |  8 ++++
 tensorflow/c/eager/c_api_internal.h           |  5 +++
 tensorflow/python/eager/BUILD                 | 21 ++++++++++
 tensorflow/python/eager/cancellation.py       | 40 +++++++++++++++++++
 tensorflow/python/eager/cancellation_test.py  | 34 ++++++++++++++++
 tensorflow/python/pywrap_tfe.i                |  4 ++
 8 files changed, 144 insertions(+)
 create mode 100644 tensorflow/python/eager/cancellation.py
 create mode 100644 tensorflow/python/eager/cancellation_test.py

diff --git a/tensorflow/c/eager/c_api_experimental.cc b/tensorflow/c/eager/c_api_experimental.cc
index 265f91b0831..975419ca8a0 100644
--- a/tensorflow/c/eager/c_api_experimental.cc
+++ b/tensorflow/c/eager/c_api_experimental.cc
@@ -541,3 +541,22 @@ TFE_MonitoringSamplerCell* TFE_MonitoringGetCellSampler2(
   return static_cast<TFE_MonitoringSamplerCell*>(
       static_cast<void*>(sampler->sampler->GetCell(label1, label2)));
 }
+
+TFE_CancellationManager* TFE_NewCancellationManager() {
+  return new TFE_CancellationManager;
+}
+
+void TFE_CancellationManagerStartCancel(
+    TFE_CancellationManager* cancellation_manager) {
+  cancellation_manager->cancellation_manager.StartCancel();
+}
+
+bool TFE_CancellationManagerIsCancelled(
+    TFE_CancellationManager* cancellation_manager) {
+  return cancellation_manager->cancellation_manager.IsCancelled();
+}
+
+void TFE_DeleteCancellationManager(
+    TFE_CancellationManager* cancellation_manager) {
+  delete cancellation_manager;
+}
diff --git a/tensorflow/c/eager/c_api_experimental.h b/tensorflow/c/eager/c_api_experimental.h
index 029a7873681..50ab0a7f09f 100644
--- a/tensorflow/c/eager/c_api_experimental.h
+++ b/tensorflow/c/eager/c_api_experimental.h
@@ -343,6 +343,19 @@ TF_CAPI_EXPORT extern void TFE_ContextSetThreadLocalMirroringPolicy(
 // thread.
 TF_CAPI_EXPORT extern TFE_ContextMirroringPolicy TFE_ContextGetMirroringPolicy(
     TFE_Context*);
+
+// -----------------------------------------------------------------------------
+// Cancellation APIs.
+
+typedef struct TFE_CancellationManager TFE_CancellationManager;
+TF_CAPI_EXPORT extern TFE_CancellationManager* TFE_NewCancellationManager();
+TF_CAPI_EXPORT extern bool TFE_CancellationManagerIsCancelled(
+    TFE_CancellationManager*);
+TF_CAPI_EXPORT extern void TFE_CancellationManagerStartCancel(
+    TFE_CancellationManager*);
+TF_CAPI_EXPORT extern void TFE_DeleteCancellationManager(
+    TFE_CancellationManager*);
+
 #ifdef __cplusplus
 } /* end extern "C" */
 #endif
diff --git a/tensorflow/c/eager/c_api_experimental_test.cc b/tensorflow/c/eager/c_api_experimental_test.cc
index b1f08c202a6..f4b02fffb4b 100644
--- a/tensorflow/c/eager/c_api_experimental_test.cc
+++ b/tensorflow/c/eager/c_api_experimental_test.cc
@@ -295,5 +295,13 @@ TEST(CAPI, MonitoringMultipleSampler) {
   TF_DeleteStatus(status);
 }
 
+TEST(CAPI, CancellationManager) {
+  TFE_CancellationManager* c_mgr = TFE_NewCancellationManager();
+  EXPECT_FALSE(TFE_CancellationManagerIsCancelled(c_mgr));
+  TFE_CancellationManagerStartCancel(c_mgr);
+  EXPECT_TRUE(TFE_CancellationManagerIsCancelled(c_mgr));
+  TFE_DeleteCancellationManager(c_mgr);
+}
+
 }  // namespace
 }  // namespace tensorflow
diff --git a/tensorflow/c/eager/c_api_internal.h b/tensorflow/c/eager/c_api_internal.h
index 9fc55a0108e..1bdfcc23fc9 100644
--- a/tensorflow/c/eager/c_api_internal.h
+++ b/tensorflow/c/eager/c_api_internal.h
@@ -36,6 +36,7 @@ limitations under the License.
 #include "tensorflow/core/common_runtime/eager/tensor_handle.h"
 #include "tensorflow/core/common_runtime/function.h"
 #include "tensorflow/core/common_runtime/rendezvous_mgr.h"
+#include "tensorflow/core/framework/cancellation.h"
 #include "tensorflow/core/framework/rendezvous.h"
 #include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/lib/core/stringpiece.h"
@@ -295,4 +296,8 @@ struct TFE_TraceContext {
   }
 };
 
+struct TFE_CancellationManager {
+  tensorflow::CancellationManager cancellation_manager;
+};
+
 #endif  // TENSORFLOW_C_EAGER_C_API_INTERNAL_H_
diff --git a/tensorflow/python/eager/BUILD b/tensorflow/python/eager/BUILD
index e431e4e156c..bdc6920e6d6 100644
--- a/tensorflow/python/eager/BUILD
+++ b/tensorflow/python/eager/BUILD
@@ -84,6 +84,27 @@ py_library(
     ],
 )
 
+py_library(
+    name = "cancellation",
+    srcs = ["cancellation.py"],
+    srcs_version = "PY2AND3",
+    visibility = ["//tensorflow:internal"],
+    deps = [
+        "//tensorflow/python:pywrap_tensorflow",
+    ],
+)
+
+cuda_py_test(
+    name = "cancellation_test",
+    size = "small",
+    srcs = ["cancellation_test.py"],
+    additional_deps = [
+        ":cancellation",
+        ":test",
+    ],
+    tags = ["no_pip"],
+)
+
 py_library(
     name = "context",
     srcs = ["context.py"],
diff --git a/tensorflow/python/eager/cancellation.py b/tensorflow/python/eager/cancellation.py
new file mode 100644
index 00000000000..72e435ef41c
--- /dev/null
+++ b/tensorflow/python/eager/cancellation.py
@@ -0,0 +1,40 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Cancellation support for eager execution."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python import pywrap_tensorflow
+
+
+class CancellationManager(object):
+  """A mechanism for cancelling blocking computation."""
+
+  def __init__(self):
+    self._impl = pywrap_tensorflow.TFE_NewCancellationManager()
+
+  @property
+  def is_cancelled(self):
+    """Returns `True` if `CancellationManager.start_cancel` has been called."""
+    return pywrap_tensorflow.TFE_CancellationManagerIsCancelled(self._impl)
+
+  def start_cancel(self):
+    """Cancels blocking operations that have been registered with this object."""
+    pywrap_tensorflow.TFE_CancellationManagerStartCancel(self._impl)
+
+  def __del__(self):
+    pywrap_tensorflow.TFE_DeleteCancellationManager(self._impl)
diff --git a/tensorflow/python/eager/cancellation_test.py b/tensorflow/python/eager/cancellation_test.py
new file mode 100644
index 00000000000..a5413f2f468
--- /dev/null
+++ b/tensorflow/python/eager/cancellation_test.py
@@ -0,0 +1,34 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.eager import cancellation
+from tensorflow.python.platform import test
+
+
+class CancellationTest(test.TestCase):
+
+  def testStartCancel(self):
+    manager = cancellation.CancellationManager()
+    self.assertFalse(manager.is_cancelled)
+    manager.start_cancel()
+    self.assertTrue(manager.is_cancelled)
+
+
+if __name__ == '__main__':
+  test.main()
diff --git a/tensorflow/python/pywrap_tfe.i b/tensorflow/python/pywrap_tfe.i
index 6f68a2c0548..9d53d23812e 100755
--- a/tensorflow/python/pywrap_tfe.i
+++ b/tensorflow/python/pywrap_tfe.i
@@ -159,6 +159,10 @@ limitations under the License.
 %rename("%s") TFE_MonitoringNewSampler2;
 %rename("%s") TFE_MonitoringDeleteSampler2;
 %rename("%s") TFE_MonitoringGetCellSampler2;
+%rename("%s") TFE_NewCancellationManager;
+%rename("%s") TFE_CancellationManagerIsCancelled;
+%rename("%s") TFE_CancellationManagerStartCancel;
+%rename("%s") TFE_DeleteCancellationManager;
 
 %{
 #include "tensorflow/python/eager/pywrap_tfe.h"