From 745028b1fba059062128647ea6dfa8cd8021dfea Mon Sep 17 00:00:00 2001 From: Alexandre Passos Date: Fri, 15 Feb 2019 14:56:52 -0800 Subject: [PATCH] Fixes sticky cancellation error in tf.function Fixes #25382 PiperOrigin-RevId: 234218637 --- .../core/common_runtime/eager/kernel_and_device.cc | 2 ++ tensorflow/core/framework/cancellation.cc | 6 ++++++ tensorflow/core/framework/cancellation.h | 3 +++ tensorflow/python/eager/def_function_test.py | 14 ++++++++++++++ 4 files changed, 25 insertions(+) diff --git a/tensorflow/core/common_runtime/eager/kernel_and_device.cc b/tensorflow/core/common_runtime/eager/kernel_and_device.cc index 41b4608c7e7..91da8bb96f9 100644 --- a/tensorflow/core/common_runtime/eager/kernel_and_device.cc +++ b/tensorflow/core/common_runtime/eager/kernel_and_device.cc @@ -194,6 +194,7 @@ Status KernelAndDeviceOp::Run(ScopedStepContainer* step_container, params.slice_reader_cache = &slice_reader_cache_; params.rendezvous = rendez_; params.cancellation_manager = &cm_; + cm_.Reset(); params.log_memory = log_memory_; std::unique_ptr step_stats_collector; if (stats != nullptr) { @@ -258,6 +259,7 @@ Status KernelAndDeviceFunc::Run( opts.rendezvous = nullptr; opts.create_rendezvous = true; opts.cancellation_manager = &cm_; + cm_.Reset(); // eager runtime does not yet support collective ops. opts.collective_executor = nullptr; opts.allow_dead_tensors = true; diff --git a/tensorflow/core/framework/cancellation.cc b/tensorflow/core/framework/cancellation.cc index af59500aee3..7f639b5ca9a 100644 --- a/tensorflow/core/framework/cancellation.cc +++ b/tensorflow/core/framework/cancellation.cc @@ -27,6 +27,12 @@ CancellationManager::CancellationManager() is_cancelled_(false), next_cancellation_token_(0) {} +void CancellationManager::Reset() { + mutex_lock l(mu_); + is_cancelling_ = false; + is_cancelled_.store(false); +} + void CancellationManager::StartCancel() { gtl::FlatMap callbacks_to_run; { diff --git a/tensorflow/core/framework/cancellation.h b/tensorflow/core/framework/cancellation.h index 7a5d9424867..51b200423ec 100644 --- a/tensorflow/core/framework/cancellation.h +++ b/tensorflow/core/framework/cancellation.h @@ -56,6 +56,9 @@ class CancellationManager { // Returns true iff StartCancel() has been called. bool IsCancelled() { return is_cancelled_.load(std::memory_order_acquire); } + // Resets the cancellation manager to its original pre-cancelled state. + void Reset(); + // Returns a token that must be used in calls to RegisterCallback // and DeregisterCallback. CancellationToken get_cancellation_token(); diff --git a/tensorflow/python/eager/def_function_test.py b/tensorflow/python/eager/def_function_test.py index 3d107d59524..ad4564c19fd 100644 --- a/tensorflow/python/eager/def_function_test.py +++ b/tensorflow/python/eager/def_function_test.py @@ -26,12 +26,14 @@ from tensorflow.python.eager import def_function from tensorflow.python.eager import lift_to_graph from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_spec from tensorflow.python.framework import test_util from tensorflow.python.keras.engine import training from tensorflow.python.keras.layers import core from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops from tensorflow.python.ops import resource_variable_ops @@ -286,6 +288,18 @@ class DefFunctionTest(test.TestCase): with self.assertRaisesRegexp(ValueError, 'inner'): f(array_ops.zeros(shape=(8, 42, 3))) + def testRuntimeErrorNotSticky(self): + + @def_function.function + def fail(i): + control_flow_ops.Assert(math_ops.equal(i, 0), ['ick']) + + fail(constant_op.constant(0)) # OK + with self.assertRaises(errors.InvalidArgumentError): + fail(constant_op.constant(1)) # InvalidArgument: "ick" + fail(constant_op.constant(0)) # OK + + def test_serialization_signature_cache(self): @def_function.function