Fixes sticky cancellation error in tf.function
Fixes #25382.
PiperOrigin-RevId: 234218637
This commit is contained in:
parent
e2a04a904c
commit
745028b1fb
@ -194,6 +194,7 @@ Status KernelAndDeviceOp::Run(ScopedStepContainer* step_container,
|
|||||||
params.slice_reader_cache = &slice_reader_cache_;
|
params.slice_reader_cache = &slice_reader_cache_;
|
||||||
params.rendezvous = rendez_;
|
params.rendezvous = rendez_;
|
||||||
params.cancellation_manager = &cm_;
|
params.cancellation_manager = &cm_;
|
||||||
|
cm_.Reset();
|
||||||
params.log_memory = log_memory_;
|
params.log_memory = log_memory_;
|
||||||
std::unique_ptr<StepStatsCollector> step_stats_collector;
|
std::unique_ptr<StepStatsCollector> step_stats_collector;
|
||||||
if (stats != nullptr) {
|
if (stats != nullptr) {
|
||||||
@ -258,6 +259,7 @@ Status KernelAndDeviceFunc::Run(
|
|||||||
opts.rendezvous = nullptr;
|
opts.rendezvous = nullptr;
|
||||||
opts.create_rendezvous = true;
|
opts.create_rendezvous = true;
|
||||||
opts.cancellation_manager = &cm_;
|
opts.cancellation_manager = &cm_;
|
||||||
|
cm_.Reset();
|
||||||
// eager runtime does not yet support collective ops.
|
// eager runtime does not yet support collective ops.
|
||||||
opts.collective_executor = nullptr;
|
opts.collective_executor = nullptr;
|
||||||
opts.allow_dead_tensors = true;
|
opts.allow_dead_tensors = true;
|
||||||
|
@ -27,6 +27,12 @@ CancellationManager::CancellationManager()
|
|||||||
is_cancelled_(false),
|
is_cancelled_(false),
|
||||||
next_cancellation_token_(0) {}
|
next_cancellation_token_(0) {}
|
||||||
|
|
||||||
|
// Returns the manager to its non-cancelled state: clears is_cancelling_ and
// is_cancelled_, so a subsequent IsCancelled() reports false again.
// Only these two flags are modified in this body.
// NOTE(review): presumably callers must ensure no cancellation is in flight
// when this runs -- confirm against StartCancel().
void CancellationManager::Reset() {
|
||||||
|
mutex_lock l(mu_);  // Both flags are cleared while holding mu_.
|
||||||
|
is_cancelling_ = false;
|
||||||
|
is_cancelled_.store(false);
|
||||||
|
}
|
||||||
|
|
||||||
void CancellationManager::StartCancel() {
|
void CancellationManager::StartCancel() {
|
||||||
gtl::FlatMap<CancellationToken, CancelCallback> callbacks_to_run;
|
gtl::FlatMap<CancellationToken, CancelCallback> callbacks_to_run;
|
||||||
{
|
{
|
||||||
|
@ -56,6 +56,9 @@ class CancellationManager {
|
|||||||
// Returns true iff StartCancel() has been called.
|
// Returns true iff StartCancel() has been called.
|
||||||
bool IsCancelled() { return is_cancelled_.load(std::memory_order_acquire); }
|
bool IsCancelled() { return is_cancelled_.load(std::memory_order_acquire); }
|
||||||
|
|
||||||
|
// Resets the cancellation manager to its original pre-cancelled state.
|
||||||
|
void Reset();
|
||||||
|
|
||||||
// Returns a token that must be used in calls to RegisterCallback
|
// Returns a token that must be used in calls to RegisterCallback
|
||||||
// and DeregisterCallback.
|
// and DeregisterCallback.
|
||||||
CancellationToken get_cancellation_token();
|
CancellationToken get_cancellation_token();
|
||||||
|
@ -26,12 +26,14 @@ from tensorflow.python.eager import def_function
|
|||||||
from tensorflow.python.eager import lift_to_graph
|
from tensorflow.python.eager import lift_to_graph
|
||||||
from tensorflow.python.framework import constant_op
|
from tensorflow.python.framework import constant_op
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
|
from tensorflow.python.framework import errors
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.framework import tensor_spec
|
from tensorflow.python.framework import tensor_spec
|
||||||
from tensorflow.python.framework import test_util
|
from tensorflow.python.framework import test_util
|
||||||
from tensorflow.python.keras.engine import training
|
from tensorflow.python.keras.engine import training
|
||||||
from tensorflow.python.keras.layers import core
|
from tensorflow.python.keras.layers import core
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
|
from tensorflow.python.ops import control_flow_ops
|
||||||
from tensorflow.python.ops import math_ops
|
from tensorflow.python.ops import math_ops
|
||||||
from tensorflow.python.ops import random_ops
|
from tensorflow.python.ops import random_ops
|
||||||
from tensorflow.python.ops import resource_variable_ops
|
from tensorflow.python.ops import resource_variable_ops
|
||||||
@ -286,6 +288,18 @@ class DefFunctionTest(test.TestCase):
|
|||||||
with self.assertRaisesRegexp(ValueError, 'inner'):
|
with self.assertRaisesRegexp(ValueError, 'inner'):
|
||||||
f(array_ops.zeros(shape=(8, 42, 3)))
|
f(array_ops.zeros(shape=(8, 42, 3)))
|
||||||
|
|
||||||
|
# Regression test: a runtime error raised by one call of a
# def_function.function must not make later, valid calls of the same
# function fail.
def testRuntimeErrorNotSticky(self):
|
||||||
|
|
||||||
|
# `fail` asserts that its argument equals 0 and reports 'ick' otherwise.
@def_function.function
|
||||||
|
def fail(i):
|
||||||
|
control_flow_ops.Assert(math_ops.equal(i, 0), ['ick'])
|
||||||
|
|
||||||
|
fail(constant_op.constant(0)) # OK
|
||||||
|
with self.assertRaises(errors.InvalidArgumentError):
|
||||||
|
fail(constant_op.constant(1)) # InvalidArgument: "ick"
|
||||||
|
# Same function, valid input again: this call must succeed even though the
# previous call failed.
fail(constant_op.constant(0)) # OK
|
||||||
|
|
||||||
|
|
||||||
def test_serialization_signature_cache(self):
|
def test_serialization_signature_cache(self):
|
||||||
|
|
||||||
@def_function.function
|
@def_function.function
|
||||||
|
Loading…
Reference in New Issue
Block a user