Fixes sticky cancellation error in tf.function
Fixes #25382 PiperOrigin-RevId: 234218637
This commit is contained in:
parent
e2a04a904c
commit
745028b1fb
@ -194,6 +194,7 @@ Status KernelAndDeviceOp::Run(ScopedStepContainer* step_container,
|
||||
params.slice_reader_cache = &slice_reader_cache_;
|
||||
params.rendezvous = rendez_;
|
||||
params.cancellation_manager = &cm_;
|
||||
cm_.Reset();
|
||||
params.log_memory = log_memory_;
|
||||
std::unique_ptr<StepStatsCollector> step_stats_collector;
|
||||
if (stats != nullptr) {
|
||||
@ -258,6 +259,7 @@ Status KernelAndDeviceFunc::Run(
|
||||
opts.rendezvous = nullptr;
|
||||
opts.create_rendezvous = true;
|
||||
opts.cancellation_manager = &cm_;
|
||||
cm_.Reset();
|
||||
// eager runtime does not yet support collective ops.
|
||||
opts.collective_executor = nullptr;
|
||||
opts.allow_dead_tensors = true;
|
||||
|
@ -27,6 +27,12 @@ CancellationManager::CancellationManager()
|
||||
is_cancelled_(false),
|
||||
next_cancellation_token_(0) {}
|
||||
|
||||
// Puts the manager back into its not-cancelled state so the same instance
// can be handed to a subsequent run. Only the two cancellation flags are
// cleared here; no other member is modified by this function.
void CancellationManager::Reset() {
  // Hold mu_ for the whole update so both flags change together with respect
  // to any other code that takes the same lock.
  mutex_lock lock(mu_);
  is_cancelled_.store(false);
  is_cancelling_ = false;
}
|
||||
|
||||
void CancellationManager::StartCancel() {
|
||||
gtl::FlatMap<CancellationToken, CancelCallback> callbacks_to_run;
|
||||
{
|
||||
|
@ -56,6 +56,9 @@ class CancellationManager {
|
||||
// Returns true iff StartCancel() has been called and Reset() has not been called since.
|
||||
bool IsCancelled() { return is_cancelled_.load(std::memory_order_acquire); }
|
||||
|
||||
// Resets the cancellation manager to its original pre-cancelled state.
|
||||
void Reset();
|
||||
|
||||
// Returns a token that must be used in calls to RegisterCallback
|
||||
// and DeregisterCallback.
|
||||
CancellationToken get_cancellation_token();
|
||||
|
@ -26,12 +26,14 @@ from tensorflow.python.eager import def_function
|
||||
from tensorflow.python.eager import lift_to_graph
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import tensor_spec
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.keras.engine import training
|
||||
from tensorflow.python.keras.layers import core
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import random_ops
|
||||
from tensorflow.python.ops import resource_variable_ops
|
||||
@ -286,6 +288,18 @@ class DefFunctionTest(test.TestCase):
|
||||
with self.assertRaisesRegexp(ValueError, 'inner'):
|
||||
f(array_ops.zeros(shape=(8, 42, 3)))
|
||||
|
||||
def testRuntimeErrorNotSticky(self):
  """A runtime failure in one call must not poison later calls.

  Calls the same traced function three times: with a passing input, with an
  input that trips the Assert op, and with a passing input again. The last
  call is expected to complete without raising.
  """

  @def_function.function
  def assert_is_zero(value):
    # Raises InvalidArgumentError at run time when `value` != 0.
    control_flow_ops.Assert(math_ops.equal(value, 0), ['ick'])

  passing_input = constant_op.constant(0)
  failing_input = constant_op.constant(1)

  # First call succeeds.
  assert_is_zero(passing_input)

  # Second call fails with InvalidArgument: "ick".
  with self.assertRaises(errors.InvalidArgumentError):
    assert_is_zero(failing_input)

  # Third call succeeds again; the earlier error must not carry over.
  assert_is_zero(constant_op.constant(0))
|
||||
|
||||
|
||||
def test_serialization_signature_cache(self):
|
||||
|
||||
@def_function.function
|
||||
|
Loading…
Reference in New Issue
Block a user