diff --git a/tensorflow/c/eager/c_api_experimental.cc b/tensorflow/c/eager/c_api_experimental.cc index 265f91b0831..975419ca8a0 100644 --- a/tensorflow/c/eager/c_api_experimental.cc +++ b/tensorflow/c/eager/c_api_experimental.cc @@ -541,3 +541,22 @@ TFE_MonitoringSamplerCell* TFE_MonitoringGetCellSampler2( return static_cast( static_cast(sampler->sampler->GetCell(label1, label2))); } + +TFE_CancellationManager* TFE_NewCancellationManager() { + return new TFE_CancellationManager; +} + +void TFE_CancellationManagerStartCancel( + TFE_CancellationManager* cancellation_manager) { + cancellation_manager->cancellation_manager.StartCancel(); +} + +bool TFE_CancellationManagerIsCancelled( + TFE_CancellationManager* cancellation_manager) { + return cancellation_manager->cancellation_manager.IsCancelled(); +} + +void TFE_DeleteCancellationManager( + TFE_CancellationManager* cancellation_manager) { + delete cancellation_manager; +} diff --git a/tensorflow/c/eager/c_api_experimental.h b/tensorflow/c/eager/c_api_experimental.h index 029a7873681..50ab0a7f09f 100644 --- a/tensorflow/c/eager/c_api_experimental.h +++ b/tensorflow/c/eager/c_api_experimental.h @@ -343,6 +343,19 @@ TF_CAPI_EXPORT extern void TFE_ContextSetThreadLocalMirroringPolicy( // thread. TF_CAPI_EXPORT extern TFE_ContextMirroringPolicy TFE_ContextGetMirroringPolicy( TFE_Context*); + +// ----------------------------------------------------------------------------- +// Cancellation APIs. + +typedef struct TFE_CancellationManager TFE_CancellationManager; +TF_CAPI_EXPORT extern TFE_CancellationManager* TFE_NewCancellationManager(); +TF_CAPI_EXPORT extern bool TFE_CancellationManagerIsCancelled( + TFE_CancellationManager*); +TF_CAPI_EXPORT extern void TFE_CancellationManagerStartCancel( + TFE_CancellationManager*); +TF_CAPI_EXPORT extern void TFE_DeleteCancellationManager( + TFE_CancellationManager*); + #ifdef __cplusplus } /* end extern "C" */ #endif diff --git a/tensorflow/c/eager/c_api_experimental_test.cc b/tensorflow/c/eager/c_api_experimental_test.cc index b1f08c202a6..f4b02fffb4b 100644 --- a/tensorflow/c/eager/c_api_experimental_test.cc +++ b/tensorflow/c/eager/c_api_experimental_test.cc @@ -295,5 +295,13 @@ TEST(CAPI, MonitoringMultipleSampler) { TF_DeleteStatus(status); } +TEST(CAPI, CancellationManager) { + TFE_CancellationManager* c_mgr = TFE_NewCancellationManager(); + EXPECT_FALSE(TFE_CancellationManagerIsCancelled(c_mgr)); + TFE_CancellationManagerStartCancel(c_mgr); + EXPECT_TRUE(TFE_CancellationManagerIsCancelled(c_mgr)); + TFE_DeleteCancellationManager(c_mgr); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/c/eager/c_api_internal.h b/tensorflow/c/eager/c_api_internal.h index 9fc55a0108e..1bdfcc23fc9 100644 --- a/tensorflow/c/eager/c_api_internal.h +++ b/tensorflow/c/eager/c_api_internal.h @@ -36,6 +36,7 @@ limitations under the License. #include "tensorflow/core/common_runtime/eager/tensor_handle.h" #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/common_runtime/rendezvous_mgr.h" +#include "tensorflow/core/framework/cancellation.h" #include "tensorflow/core/framework/rendezvous.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/stringpiece.h" @@ -295,4 +296,8 @@ struct TFE_TraceContext { } }; +struct TFE_CancellationManager { + tensorflow::CancellationManager cancellation_manager; +}; + #endif // TENSORFLOW_C_EAGER_C_API_INTERNAL_H_ diff --git a/tensorflow/python/eager/BUILD b/tensorflow/python/eager/BUILD index e431e4e156c..bdc6920e6d6 100644 --- a/tensorflow/python/eager/BUILD +++ b/tensorflow/python/eager/BUILD @@ -84,6 +84,27 @@ py_library( ], ) +py_library( + name = "cancellation", + srcs = ["cancellation.py"], + srcs_version = "PY2AND3", + visibility = ["//tensorflow:internal"], + deps = [ + "//tensorflow/python:pywrap_tensorflow", + ], +) + +cuda_py_test( + name = "cancellation_test", + size = "small", + srcs = ["cancellation_test.py"], + additional_deps = [ + ":cancellation", + ":test", + ], + tags = ["no_pip"], +) + py_library( name = "context", srcs = ["context.py"], diff --git a/tensorflow/python/eager/cancellation.py b/tensorflow/python/eager/cancellation.py new file mode 100644 index 00000000000..72e435ef41c --- /dev/null +++ b/tensorflow/python/eager/cancellation.py @@ -0,0 +1,40 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Cancellation support for eager execution.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python import pywrap_tensorflow + + +class CancellationManager(object): + """A mechanism for cancelling blocking computation.""" + + def __init__(self): + self._impl = pywrap_tensorflow.TFE_NewCancellationManager() + + @property + def is_cancelled(self): + """Returns `True` if `CancellationManager.start_cancel` has been called.""" + return pywrap_tensorflow.TFE_CancellationManagerIsCancelled(self._impl) + + def start_cancel(self): + """Cancels blocking operations that have been registered with this object.""" + pywrap_tensorflow.TFE_CancellationManagerStartCancel(self._impl) + + def __del__(self): + pywrap_tensorflow.TFE_DeleteCancellationManager(self._impl) diff --git a/tensorflow/python/eager/cancellation_test.py b/tensorflow/python/eager/cancellation_test.py new file mode 100644 index 00000000000..a5413f2f468 --- /dev/null +++ b/tensorflow/python/eager/cancellation_test.py @@ -0,0 +1,34 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.eager import cancellation +from tensorflow.python.platform import test + + +class CancellationTest(test.TestCase): + + def testStartCancel(self): + manager = cancellation.CancellationManager() + self.assertFalse(manager.is_cancelled) + manager.start_cancel() + self.assertTrue(manager.is_cancelled) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/python/pywrap_tfe.i b/tensorflow/python/pywrap_tfe.i index 6f68a2c0548..9d53d23812e 100755 --- a/tensorflow/python/pywrap_tfe.i +++ b/tensorflow/python/pywrap_tfe.i @@ -159,6 +159,10 @@ limitations under the License. %rename("%s") TFE_MonitoringNewSampler2; %rename("%s") TFE_MonitoringDeleteSampler2; %rename("%s") TFE_MonitoringGetCellSampler2; +%rename("%s") TFE_NewCancellationManager; +%rename("%s") TFE_CancellationManagerIsCancelled; +%rename("%s") TFE_CancellationManagerStartCancel; +%rename("%s") TFE_DeleteCancellationManager; %{ #include "tensorflow/python/eager/pywrap_tfe.h"