Fixes sticky cancellation error in tf.function

Fixes #25382

PiperOrigin-RevId: 234218637
This commit is contained in:
Alexandre Passos 2019-02-15 14:56:52 -08:00 committed by TensorFlower Gardener
parent e2a04a904c
commit 745028b1fb
4 changed files with 25 additions and 0 deletions

View File

@ -194,6 +194,7 @@ Status KernelAndDeviceOp::Run(ScopedStepContainer* step_container,
params.slice_reader_cache = &slice_reader_cache_;
params.rendezvous = rendez_;
params.cancellation_manager = &cm_;
cm_.Reset();
params.log_memory = log_memory_;
std::unique_ptr<StepStatsCollector> step_stats_collector;
if (stats != nullptr) {
@ -258,6 +259,7 @@ Status KernelAndDeviceFunc::Run(
opts.rendezvous = nullptr;
opts.create_rendezvous = true;
opts.cancellation_manager = &cm_;
cm_.Reset();
// eager runtime does not yet support collective ops.
opts.collective_executor = nullptr;
opts.allow_dead_tensors = true;

View File

@ -27,6 +27,12 @@ CancellationManager::CancellationManager()
is_cancelled_(false),
next_cancellation_token_(0) {}
// Puts the manager back into its non-cancelled state so it can be reused.
// Both the "cancellation in progress" flag and the "cancelled" flag are
// cleared while mu_ is held.
//
// NOTE(review): only these two flags are touched; confirm that nothing else
// left over from an earlier cancellation needs to be reset as well.
void CancellationManager::Reset() {
  mutex_lock lock(mu_);
  is_cancelling_ = false;
  // Same ordering as the default argument of store(), spelled out explicitly.
  is_cancelled_.store(false, std::memory_order_seq_cst);
}
void CancellationManager::StartCancel() {
gtl::FlatMap<CancellationToken, CancelCallback> callbacks_to_run;
{

View File

@ -56,6 +56,9 @@ class CancellationManager {
// Returns true iff StartCancel() has been called and Reset() has not been
// called since.
bool IsCancelled() { return is_cancelled_.load(std::memory_order_acquire); }
// Resets the cancellation manager to its original pre-cancelled state, so
// that IsCancelled() returns false again.
// NOTE(review): the definition clears only the is_cancelling_ and
// is_cancelled_ flags; confirm that no run is still using this manager at the
// time it is reset.
void Reset();
// Returns a token that must be used in calls to RegisterCallback
// and DeregisterCallback.
// NOTE(review): Reset() leaves next_cancellation_token_ untouched, so tokens
// presumably stay distinct across resets; confirm against the definition.
CancellationToken get_cancellation_token();

View File

@ -26,12 +26,14 @@ from tensorflow.python.eager import def_function
from tensorflow.python.eager import lift_to_graph
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import test_util
from tensorflow.python.keras.engine import training
from tensorflow.python.keras.layers import core
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import resource_variable_ops
@ -286,6 +288,18 @@ class DefFunctionTest(test.TestCase):
with self.assertRaisesRegexp(ValueError, 'inner'):
f(array_ops.zeros(shape=(8, 42, 3)))
def testRuntimeErrorNotSticky(self):
  """Checks that one failing call does not make later calls fail too."""

  @def_function.function
  def fail(i):
    # The assertion only holds when `i` equals 0.
    control_flow_ops.Assert(math_ops.equal(i, 0), ['ick'])

  # A passing input works before any failure has occurred.
  fail(constant_op.constant(0))

  # A failing input surfaces InvalidArgument: "ick".
  self.assertRaises(
      errors.InvalidArgumentError, fail, constant_op.constant(1))

  # The earlier error must not stick: a passing input works again.
  fail(constant_op.constant(0))
def test_serialization_signature_cache(self):
@def_function.function