Cancel in-flight closures when there is an error.

PiperOrigin-RevId: 324542620
Change-Id: I1d6cddf8130df74f00ce7b0a3b6b84f553990e78
Yuefeng Zhou 2020-08-02 22:11:18 -07:00 committed by TensorFlower Gardener
parent 37a0da627c
commit 151bd5901a
4 changed files with 252 additions and 238 deletions
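For context, the mechanism this change builds on is TensorFlow's internal cancellation API (`tensorflow.python.eager.cancellation`, non-public and subject to change). A minimal sketch of its behavior, mirroring how the client uses it below: a concrete function wrapped by a `CancellationManager` refuses to run once `start_cancel()` has been called.

from tensorflow.python.eager import cancellation
from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import errors

@def_function.function
def fn():
  return constant_op.constant(1.0)

mgr = cancellation.CancellationManager()
cancelable = mgr.get_cancelable_function(fn.get_concrete_function())
mgr.start_cancel()  # what the closure queue does when a closure fails
try:
  cancelable()      # raises instead of executing the function
except errors.CancelledError:
  pass              # in-flight and queued work is abandoned here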


@@ -32,6 +32,7 @@ py_library(
"//tensorflow/python/distribute:input_lib",
"//tensorflow/python/distribute:parameter_server_strategy_v2",
"//tensorflow/python/distribute:values",
"//tensorflow/python/eager:cancellation",
"//tensorflow/python/eager:context",
"//tensorflow/python/eager:def_function",
"//tensorflow/python/eager:executor",


@@ -31,15 +31,19 @@ import threading
import weakref
from absl import logging
from six.moves import queue
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import input_lib
from tensorflow.python.distribute import parameter_server_strategy_v2
from tensorflow.python.distribute.client import metric_utils
from tensorflow.python.eager import cancellation
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import executor
from tensorflow.python.eager import function as tf_function
from tensorflow.python.eager import remote
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import func_graph
from tensorflow.python.framework import ops
@@ -247,20 +251,28 @@ class PerWorkerValues(object):
self._values = tuple(values)
def _select_worker_slice(worker_id, structured):
"""Selects the worker slice of each of the items in `structured`."""
def _get(x):
return x._values[worker_id] if isinstance(x, PerWorkerValues) else x # pylint: disable=protected-access
return nest.map_structure(_get, structured)
class Closure(object):
"""Hold a function to be scheduled and its arguments."""
def __init__(self, function, args=None, kwargs=None):
def __init__(self, function, cancellation_mgr, args=None, kwargs=None):
if not callable(function):
raise ValueError("Function passed to `Client.schedule` must be a "
"callable object.")
self._args = args or ()
self._kwargs = kwargs or {}
self._function = function
if isinstance(function, def_function.Function):
replica_args = self._select_worker_slice(0, self._args)
replica_kwargs = self._select_worker_slice(0, self._kwargs)
replica_args = _select_worker_slice(0, self._args)
replica_kwargs = _select_worker_slice(0, self._kwargs)
# Note: no need to handle function registration failure since this kind of
# failure will not raise exceptions as designed in the runtime. The client
@@ -276,25 +288,22 @@ class Closure(object):
concrete_function = function.get_concrete_function(
*nest.map_structure(_maybe_as_type_spec, replica_args),
**nest.map_structure(_maybe_as_type_spec, replica_kwargs))
self._function = cancellation_mgr.get_cancelable_function(
concrete_function)
self._output_remote_values = nest.map_structure(
lambda x: RemoteValue(self, x), concrete_function.structured_outputs)
elif isinstance(function, tf_function.ConcreteFunction):
self._function = cancellation_mgr.get_cancelable_function(
function)
self._output_remote_values = nest.map_structure(
lambda x: RemoteValue(self, x), function.structured_outputs)
else:
# Regular python functions.
self._function = function
# TODO(yuefengz): maybe we should trace python functions if their inputs
# are Python primitives, tensors and composite tensors.
self._output_remote_values = RemoteValue(self, None)
def _select_worker_slice(self, worker_id, structured):
"""Selects the worker slice of each of the items in `structured`."""
def _get(x):
return x._values[worker_id] if isinstance(x, PerWorkerValues) else x # pylint: disable=protected-access
return nest.map_structure(_get, structured)
def _fetch_output_remote_values(self):
"""Temporary method used to sync the scheduler."""
# It will do nothing if there is no return value.
@@ -319,9 +328,8 @@ class Closure(object):
Args:
worker: a `Worker` object.
"""
replica_args = self._select_worker_slice(worker.worker_index, self._args)
replica_kwargs = self._select_worker_slice(worker.worker_index,
self._kwargs)
replica_args = _select_worker_slice(worker.worker_index, self._args)
replica_kwargs = _select_worker_slice(worker.worker_index, self._kwargs)
e = (
_maybe_get_error_and_rebuild_remote_values(worker, replica_args) or
@@ -350,8 +358,7 @@ class _CoordinatedClosureQueue(object):
This class is thread-safe.
"""
def __init__(self):
def __init__(self, cancellation_mgr):
# `self._inflight_closure_count` only tracks the number of inflight closures
# that are "in generation". Once an error occurs, error generation is
# incremented and all subsequent arriving closures (from inflight) are
@@ -359,17 +366,26 @@ class _CoordinatedClosureQueue(object):
self._inflight_closure_count = 0
self._queue_lock = threading.Lock()
# Condition indicating that all pending closures (either queued or inflight)
# have been processed, failed, or cancelled.
self._stop_waiting_condition = threading.Condition(self._queue_lock)
# Condition indicating that an item becomes available in queue (not empty).
self._closures_queued_condition = threading.Condition(self._queue_lock)
# Condition indicating that a queue slot becomes available (not full).
# Note that even with "infinite" queue size, there is still a "practical"
# size limit for the queue depending on host memory capacity, and thus the
# queue will eventually become full with a lot of enqueued closures.
self._queue_free_slot_condition = threading.Condition(self._queue_lock)
# Condition indicating that there are no inflight closures.
self._no_inflight_closure_condition = threading.Condition(self._queue_lock)
# Used to cancel in-flight closures.
self._cancellation_mgr = cancellation_mgr
if _CLOSURE_QUEUE_MAX_SIZE <= 0:
logging.warning(
"In ParameterServerClient, creating an infinite closure queue can "
@@ -377,31 +393,6 @@ class _CoordinatedClosureQueue(object):
self._queue = queue.Queue(maxsize=_CLOSURE_QUEUE_MAX_SIZE)
self._error = None
# Error generation is a counter that helps us track whether a closure
# should be cancelled when it is being put back to `self._queue`. It works
# in the following way:
# 1) Error generation starts off at 0.
# 2) When a worker thread calls `get()`, the closure's error generation
# is copied from this queue's error generation.
# 3) If any worker thread experiences an error that's categorized as a
# non-retryable error, the queue's error will be set, error generation
# increments by 1, and the queue is cleared (with the closures marked
# with cancelled error), so other worker threads stop getting closures
# from the queue. Worker preemption is categorized as a retryable error.
# 4) At this point, if `put()` or `wait()` is called (usually by the main
# thread via `schedule` and `join`), the error is raised through that
# call.
# 5) The closures that are inflight, i.e. that are being executed remotely,
# will not be aware of such error event. If the worker that's executing
# the closure happens to be interrupted, the closure should not be put
# back to the queue, and be cancelled with error instead. Checking the
# generation id of the closure and queue is how the worker thread tells
# whether the closure should be put back. Likewise for `mark_finished`
# and `mark_failed`: if the arriving closure is considered out of
# generation in those two methods, it is simply discarded (the inflight
# closure count still decrements).
self._error_generation = 0
# The following is a lock to make sure that while `wait` is called and
# before it returns, no `put` can be executed, because `wait` would not
# know what to do with newly put closures. This lock adds a cutoff
@@ -415,11 +406,14 @@ class _CoordinatedClosureQueue(object):
# of the code.
self._put_wait_lock = threading.Lock()
def _cancel_closures_in_queue(self):
def _cancel_all_closures(self):
"""Clears the queue and sets remaining closures cancelled error.
This method expects self._queue_lock to be held prior to entry.
"""
self._cancellation_mgr.start_cancel()
while self._inflight_closure_count > 0:
self._no_inflight_closure_condition.wait()
while True:
try:
closure = self._queue.get(block=False)
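The wait loop above pairs with the notifications issued by `mark_finished`, `put_back`, and `mark_failed` whenever the in-flight count drops to zero. A self-contained sketch of that condition-variable pattern (the class and method names here are illustrative, not part of this module):

import threading

class InflightTracker(object):
  """Illustrative stand-in for the queue's in-flight bookkeeping."""

  def __init__(self):
    self._cond = threading.Condition()
    self._inflight = 0

  def started(self):
    with self._cond:
      self._inflight += 1

  def finished(self):
    with self._cond:
      self._inflight -= 1
      if self._inflight == 0:
        self._cond.notify_all()  # wakes any thread blocked in drain()

  def drain(self):
    # Mirrors the `while self._inflight_closure_count > 0: wait()` loop above.
    with self._cond:
      while self._inflight > 0:
        self._cond.wait()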
@@ -437,8 +431,8 @@
This method expects self._queue_lock to be held prior to entry.
"""
if self._error:
self._cancel_all_closures()
try:
self._cancel_closures_in_queue()
raise self._error # pylint: disable=raising-bad-type
finally:
self._error = None
@@ -466,16 +460,17 @@
return None
closure = self._queue.get(block=False)
self._queue_free_slot_condition.notify()
closure._error_generation = self._error_generation # pylint: disable=protected-access
self._inflight_closure_count += 1
return closure
def mark_finished(self, closure):
def mark_finished(self):
"""Let the queue know that a closure has been successfully executed."""
with self._queue_lock:
if self._inflight_closure_count < 1:
raise AssertionError("There is no inflight closures to mark_finished.")
self._inflight_closure_count -= 1
if self._inflight_closure_count == 0:
self._no_inflight_closure_condition.notifyAll()
if self._queue.empty() and self._inflight_closure_count == 0:
self._stop_waiting_condition.notifyAll()
@@ -484,17 +479,15 @@
with self._queue_lock:
if self._inflight_closure_count < 1:
raise AssertionError("There is no inflight closures to put_back.")
self._inflight_closure_count -= 1
if closure._error_generation < self._error_generation: # pylint: disable=protected-access
# If the closure to put back is out of generation, cancel the closure
# and ignore it.
logging.info("Function %r should no longer be dispatched; marking "
"as cancelled.")
if self._error:
closure._set_output_remote_values_cancelled() # pylint: disable=protected-access
return
self._queue_free_slot_condition.wait_for(lambda: not self._queue.full())
self._queue.put(closure, block=False)
self._closures_queued_condition.notify()
else:
self._queue_free_slot_condition.wait_for(lambda: not self._queue.full())
self._queue.put(closure, block=False)
self._closures_queued_condition.notify()
self._inflight_closure_count -= 1
if self._inflight_closure_count == 0:
self._no_inflight_closure_condition.notifyAll()
def wait(self, timeout=None):
"""Wait for all closures to be finished before returning.
@@ -516,22 +509,18 @@
self._raise_if_error()
return True
def mark_failed(self, e, closure):
def mark_failed(self, e):
"""Sets error and unblocks any wait() call."""
with self._queue_lock:
# TODO(yuefengz): maybe record all failure and give users more
# information?
if self._inflight_closure_count < 1:
raise AssertionError("There is no inflight closures to mark_failed.")
if self._error is None:
self._error = e
self._inflight_closure_count -= 1
if closure._error_generation < self._error_generation: # pylint: disable=protected-access
# If the closure to mark fail is out of generation, simply ignore it
# (with the actual error associated with the closure preserved).
return
assert self._error is None
self._error = e
self._error_generation += 1
self._cancel_closures_in_queue()
if self._inflight_closure_count == 0:
self._no_inflight_closure_condition.notifyAll()
self._stop_waiting_condition.notifyAll()
def done(self):
@@ -678,7 +667,7 @@ class Worker(object):
# TODO(yuefengz): we don't have to materialize results every step.
with metric_utils.monitored_timer("remote_value_fetch"):
closure._fetch_output_remote_values() # pylint: disable=protected-access
self._cluster._closure_queue.mark_finished(closure) # pylint: disable=protected-access
self._cluster._closure_queue.mark_finished() # pylint: disable=protected-access
except Exception as e: # pylint: disable=broad-except
logging.error(
"/job:worker/task:%d encountered the following error when processing "
@@ -686,7 +675,7 @@ class Worker(object):
nest.map_structure(
lambda x: x._set_error(e), # pylint: disable=protected-access
closure._output_remote_values) # pylint: disable=protected-access
self._cluster._closure_queue.mark_failed(e, closure) # pylint: disable=protected-access
self._cluster._closure_queue.mark_failed(e) # pylint: disable=protected-access
def _process_queue(self):
while True:
@@ -710,7 +699,8 @@ class Worker(object):
# the same worker such as creating resources, setting resources' aborted
# status, and executing closures happen on the same thread. This allows us
# to have simpler logic of concurrency.
closure = Closure(function=function, args=args, kwargs=kwargs)
closure = Closure(
function, self._cluster._cancellation_mgr, args=args, kwargs=kwargs) # pylint: disable=protected-access
resource_remote_value = closure._output_remote_values # pylint: disable=protected-access
self._register_resource(resource_remote_value)
@@ -775,7 +765,8 @@ class Cluster(object):
protocol=cluster_resolver.rpc_layer,
cluster_device_filters=device_filters)
self._closure_queue = _CoordinatedClosureQueue()
self._cancellation_mgr = cancellation.CancellationManager()
self._closure_queue = _CoordinatedClosureQueue(self._cancellation_mgr)
self.failure_handler = WorkerPreemptionHandler(context.get_server_def())
worker_device_strings = [
"/job:worker/replica:0/task:%d" % i for i in range(self._num_workers)
@@ -796,7 +787,8 @@ class Cluster(object):
Returns:
A structure of `RemoteValue` object.
"""
closure = Closure(function=function, args=args, kwargs=kwargs)
closure = Closure(
function, self._cancellation_mgr, args=args, kwargs=kwargs)
self._closure_queue.put(closure)
return closure._output_remote_values # pylint: disable=protected-access
@@ -893,8 +885,8 @@ class Client(object):
function execution to finish and retrieve its output from the remote worker.
`schedule` guarantees that `fn` will be executed on a worker at least once;
it could be more than once if a worker fails and restarts in the middle of
function scheduling. Note that since worker can fail at any point when
it could be more than once if its corresponding worker fails in the middle
of its execution. Note that since a worker can fail at any point while
executing the function, it is possible that the function is partially
executed, but `Client` guarantees that in those events, the function will
eventually be fully executed, possibly on a different worker that is
@@ -904,14 +896,12 @@ class Client(object):
by raising any one of those errors, and clear the errors collected so far.
There are two implications when this happens: 1) user should call `schedule`
with `fn` again to re-schedule, and 2) some of the previously scheduled
functions may no longer execute. User can call `fetch` on the returned
functions may not have been executed. Users can call `fetch` on the returned
`RemoteValue` to inspect if they have executed, failed, or been cancelled, and
reschedule the corresponding function if needed.
When `schedule` raises, it is possible that there are still functions being
executed on workers, at the time `schedule` raises. When this happens, users
can call `join` again to wait for all pending async function execution to
finish, and bring the cluster into a consistent state.
When `schedule` raises, it guarantees that there is no function that is
still being executed.
At this time, there is no support of worker assignment for function
execution, or priority of the workers.
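A hedged usage sketch of the contract spelled out above (`run_steps`, `step_fn`, and `num_steps` are illustrative names; `schedule`, `join`, and `RemoteValue.fetch` are the methods defined in this file):

def run_steps(client, step_fn, num_steps):
  """Illustrative driver for the schedule/join error semantics above."""
  remote_values = [client.schedule(step_fn) for _ in range(num_steps)]
  try:
    client.join()                 # raises one of the collected errors, if any
  except Exception:               # pylint: disable=broad-except
    for rv in remote_values:
      try:
        rv.fetch()                # raises if this function failed or was cancelled
      except Exception:           # pylint: disable=broad-except
        client.schedule(step_fn)  # re-schedule only the lost work
  return remote_values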
@@ -940,7 +930,8 @@ class Client(object):
# TODO(b/160702436): Invoke `strategy.run` for user's function so it enters
# a `ReplicaContext` in a logically correct way.
with distribute_lib.ReplicaContext(
self._strategy, replica_id_in_sync_group=0):
self._strategy,
replica_id_in_sync_group=constant_op.constant(0, dtypes.int32)):
with self._translate_parameter_server_failure():
return self.cluster.schedule(fn, args=args, kwargs=kwargs)
@@ -949,17 +940,14 @@ class Client(object):
If any previously scheduled function raises an error, `join` will fail by
raising any one of those errors, and clear the errors collected so far. If
this happens, some of the previously scheduled functions may no longer
execute. Users can call `fetch` on the returned `RemoteValue` to inspect if
this happens, some of the previously scheduled functions may not have been
executed. Users can call `fetch` on the returned `RemoteValue` to inspect if
they have executed, failed, or been cancelled. If some that have been cancelled
need to be rescheduled, users should call `schedule` with the function
again.
Note: `join` raises an exception as soon as the client detects one, and this
means it is possible that there are still functions being executed on
workers, at the time `join` raises. When this happens, users can call `join`
again to wait for all pending async function execution to finish, and bring
the cluster into a consistent state.
When `join` returns or raises, it guarantees that there is no function that
is still being executed.
Raises:
Exception: one of the exceptions caught by the client by any previously
@@ -976,6 +964,9 @@ class Client(object):
If any previously scheduled function raises an error, `done` will fail by
raising any one of those errors.
When `done` returns True or raises, it guarantees that there is no function
that is still being executed.
"""
return self.cluster.done()
@@ -1091,7 +1082,7 @@ class Client(object):
raise
class _PerWorkerDistributedDataset(object): # pylint: disable=protected-access
class _PerWorkerDistributedDataset(object):
"""Represents worker-distributed datasets created from dataset function."""
def __init__(self, dataset_fn, input_workers, client):
@@ -1107,13 +1098,13 @@ class _PerWorkerDistributedDataset(object): # pylint: disable=protected-access
if isinstance(dataset_fn, def_function.Function):
with variable_scope.variable_creator_scope(disallow_variable_creation):
self._dataset_fn = dataset_fn.get_concrete_function()
elif isinstance(dataset_fn, tf_function.ConcreteFunction):
self._dataset_fn = dataset_fn
else:
dataset_fn = dataset_fn.get_concrete_function()
elif not isinstance(dataset_fn, tf_function.ConcreteFunction):
with variable_scope.variable_creator_scope(disallow_variable_creation):
self._dataset_fn = def_function.function(
dataset_fn).get_concrete_function()
dataset_fn = def_function.function(dataset_fn).get_concrete_function()
self._dataset_fn = (
client.cluster._cancellation_mgr.get_cancelable_function( # pylint: disable=protected-access
dataset_fn))
self._input_workers = input_workers
self._client = client
self._element_spec = None
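One more note on the dataset path above: whatever the user passes is first normalized to a `ConcreteFunction` and only then wrapped by the cluster's cancellation manager, so in-flight dataset creation gets cancelled along with everything else. A compact sketch of that normalization (omitting the variable-creation guard; the helper name is illustrative):

from tensorflow.python.eager import def_function
from tensorflow.python.eager import function as tf_function

def as_concrete_function(fn):
  """Illustrative mirror of the normalization in _PerWorkerDistributedDataset."""
  if isinstance(fn, def_function.Function):
    return fn.get_concrete_function()
  if isinstance(fn, tf_function.ConcreteFunction):
    return fn
  # A plain python callable: trace it into a ConcreteFunction first.
  return def_function.function(fn).get_concrete_function()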


@@ -30,22 +30,34 @@ from tensorflow.python.training import coordinator
from tensorflow.python.util import nest
class MockCancellationManager(object):
def __init__(self):
self.cancelled = False
def start_cancel(self):
self.cancelled = True
def get_cancelable_function(self, func):
return func
class CoordinatedClosureQueueTest(test.TestCase):
def testBasic(self):
queue = client._CoordinatedClosureQueue()
queue = client._CoordinatedClosureQueue(MockCancellationManager())
closure1 = self._create_closure()
queue.put(closure1)
self.assertIs(closure1, queue.get())
self.assertFalse(queue.done())
queue.put_back(closure1)
self.assertEqual(closure1, queue.get())
queue.mark_finished(closure1)
queue.mark_finished()
self.assertTrue(queue.done())
queue.wait()
def testProcessAtLeastOnce(self):
closure_queue = client._CoordinatedClosureQueue()
closure_queue = client._CoordinatedClosureQueue(MockCancellationManager())
labels = ['A', 'B', 'C', 'D', 'E']
processed_count = collections.defaultdict(int)
@@ -63,7 +75,7 @@ class CoordinatedClosureQueueTest(test.TestCase):
closure_queue.put_back(closure)
continue
closure._function()
closure_queue.mark_finished(closure)
closure_queue.mark_finished()
def get_func(label):
@@ -76,7 +88,8 @@ class CoordinatedClosureQueueTest(test.TestCase):
return func
for label in labels:
closure_queue.put(client.Closure(get_func(label)))
closure_queue.put(
client.Closure(get_func(label), MockCancellationManager()))
t1 = threading.Thread(target=process_queue, daemon=True)
t1.start()
t2 = threading.Thread(target=process_queue, daemon=True)
@@ -93,7 +106,7 @@ class CoordinatedClosureQueueTest(test.TestCase):
coord.join([t1, t2])
def testNotifyBeforeWait(self):
closure_queue = client._CoordinatedClosureQueue()
closure_queue = client._CoordinatedClosureQueue(MockCancellationManager())
def func():
logging.info('func running')
@@ -102,10 +115,10 @@ class CoordinatedClosureQueueTest(test.TestCase):
def process_queue():
with coord.stop_on_exception():
closure = closure_queue.get()
closure_queue.mark_finished(closure)
closure_queue.get()
closure_queue.mark_finished()
closure_queue.put(client.Closure(func))
closure_queue.put(client.Closure(func, MockCancellationManager()))
t = threading.Thread(target=process_queue)
t.start()
coord.join([t])
@@ -114,8 +127,30 @@ class CoordinatedClosureQueueTest(test.TestCase):
# doesn't time out.
closure_queue.wait()
def _assert_one_unblock_the_other(self, first_fn, second_fn):
"""Asserts `second_fn` wouldn't return before `first_fn` is finished."""
first_fn_done = threading.Event()
second_fn_done = threading.Event()
coord = coordinator.Coordinator(clean_stop_exception_types=[])
def wrapped_first_fn():
with coord.stop_on_exception():
self.assertFalse(second_fn_done.is_set())
first_fn()
first_fn_done.set()
self.assertFalse(first_fn_done.is_set())
t = threading.Thread(target=wrapped_first_fn)
t.start()
second_fn()
self.assertTrue(first_fn_done.is_set())
second_fn_done.set()
coord.join([t])
def testWaitRaiseErrorAfterMarkFailure(self):
closure_queue = client._CoordinatedClosureQueue()
closure_queue = client._CoordinatedClosureQueue(MockCancellationManager())
closure_queue.put(self._create_closure())
closure = closure_queue.get()
@@ -126,22 +161,17 @@ class CoordinatedClosureQueueTest(test.TestCase):
# all inflight closures are finished.
def mark_finished_fn():
with coord.stop_on_exception():
self.assertFalse(wait_finish_event.is_set())
try:
raise ValueError('Some error.')
except ValueError as e:
closure_queue.mark_failed(e, closure)
wait_finish_event.wait()
try:
raise ValueError('Some error.')
except ValueError as e:
closure_queue.mark_failed(e)
t = threading.Thread(target=mark_finished_fn)
t.start()
def wait_fn():
with self.assertRaises(ValueError):
closure_queue.wait()
with self.assertRaises(ValueError):
closure_queue.wait()
wait_finish_event.set()
self._assert_one_unblock_the_other(mark_finished_fn, wait_fn)
coord.join([t])
self.assertTrue(closure_queue.done())
def _create_closure(self):
@@ -150,10 +180,10 @@ class CoordinatedClosureQueueTest(test.TestCase):
def some_function():
return 1.0
return client.Closure(some_function)
return client.Closure(some_function, MockCancellationManager())
def _put_two_closures_and_get_one(self):
closure_queue = client._CoordinatedClosureQueue()
closure_queue = client._CoordinatedClosureQueue(MockCancellationManager())
closure1 = self._create_closure()
closure_queue.put(closure1)
@@ -166,9 +196,9 @@ class CoordinatedClosureQueueTest(test.TestCase):
return closure_queue, closure1, closure2
def testPutRaiseError(self):
closure_queue, closure1, closure2 = self._put_two_closures_and_get_one()
closure_queue, _, closure2 = self._put_two_closures_and_get_one()
closure_queue.mark_failed(ValueError(), closure1)
closure_queue.mark_failed(ValueError())
with self.assertRaises(ValueError):
closure_queue.put(self._create_closure())
@@ -185,9 +215,9 @@ class CoordinatedClosureQueueTest(test.TestCase):
closure_queue.put(self._create_closure())
def testWaitRaiseError(self):
closure_queue, closure1, closure2 = self._put_two_closures_and_get_one()
closure_queue, _, closure2 = self._put_two_closures_and_get_one()
closure_queue.mark_failed(ValueError(), closure1)
closure_queue.mark_failed(ValueError())
with self.assertRaises(ValueError):
closure_queue.wait()
@@ -203,15 +233,22 @@ class CoordinatedClosureQueueTest(test.TestCase):
closure_queue.wait()
def testDoneRaiseError(self):
closure_queue, closure1, _ = self._put_two_closures_and_get_one()
closure_queue.get()
closure_queue, _, _ = self._put_two_closures_and_get_one()
self.assertFalse(closure_queue.done())
closure_queue.mark_failed(ValueError(), closure1)
closure_queue.mark_failed(ValueError())
with self.assertRaises(ValueError):
closure_queue.done()
def _test_error_reporting_and_cancel_flow(self, call_wait):
def _set_error(self, closure_queue, closure, error):
try:
raise error
except Exception as e: # pylint: disable=broad-except
nest.map_structure(lambda x: x._set_error(e),
closure._output_remote_values)
closure_queue.mark_failed(e)
def _test_cancel_closure_when_error(self, call_wait):
closure_queue, closure1, closure2 = self._put_two_closures_and_get_one()
closure_queue.put(self._create_closure())
closure_queue.get()
@@ -219,34 +256,37 @@ class CoordinatedClosureQueueTest(test.TestCase):
self.assertEqual(closure_queue._inflight_closure_count, 2)
# Simulate closure1 failing.
try:
raise ValueError('Some error.')
except ValueError as e:
nest.map_structure(lambda x: x._set_error(e),
closure1._output_remote_values)
self.assertEqual(closure_queue._error_generation, 0) # pylint: disable=g-assert-in-except
closure_queue.mark_failed(e, closure1)
self.assertEqual(closure_queue._error_generation, 1)
# At this moment, there are one inflight, nothing
# in queue (because the ones in queue should have been removed and
# cancelled).
self.assertTrue(closure_queue._queue.empty())
# Doesn't include out of generation closures.
self._set_error(closure_queue, closure1, ValueError('Some error.'))
# At this moment, there is one closure inflight and one in the queue.
self.assertEqual(closure_queue._queue.qsize(), 1)
self.assertEqual(closure_queue._inflight_closure_count, 1)
coord = coordinator.Coordinator(clean_stop_exception_types=[])
closure3 = self._create_closure()
with self.assertRaises(ValueError):
# Verifying `wait()` or `put()` raises even if one closure is in
# flight.
if call_wait:
closure_queue.wait()
else:
closure_queue.put(closure3)
# At this moment, there is one inflight, nothing in queue.
def fake_cancellation():
self._set_error(closure_queue, closure2,
ValueError('Fake cancellation error.'))
def report_error():
# It should not report the fake cancellation error.
with self.assertRaisesRegex(ValueError, 'Some error.'):
# Verifying `wait()` or `put()` raises even if one closure is in
# flight.
if call_wait:
closure_queue.wait()
else:
closure_queue.put(closure3)
self._assert_one_unblock_the_other(fake_cancellation, report_error)
# Cancellation manager has been called.
self.assertTrue(closure_queue._cancellation_mgr.cancelled)
# At this moment, there are zero closures inflight and nothing in the queue.
self.assertTrue(closure_queue._queue.empty())
self.assertEqual(closure_queue._inflight_closure_count, 1)
self.assertEqual(closure_queue._inflight_closure_count, 0)
self.assertIsNone(closure_queue._error)
# This asserts that closure1 has errored.
with self.assertRaisesRegex(ValueError, 'Some error.'):
@@ -260,107 +300,36 @@ class CoordinatedClosureQueueTest(test.TestCase):
'function.'):
closure3._fetch_output_remote_values()
# Closure2 is inflight, so it shouldn't be ready.
# Closure2 was an inflight closure when it got cancelled.
self.assertEqual(closure2._output_remote_values._status,
client._RemoteValueStatus.NOT_READY)
# And `wait` should block because closure2 is not back yet.
self.assertFalse(closure_queue.wait(timeout=20))
# Now let's assume that closure2 isn't successful due to worker preemption,
# and now it's attempted to be put back, but ends up getting cancelled.
self.assertEqual(closure2._error_generation, 0)
self.assertEqual(closure_queue._error_generation, 1)
closure_queue.put_back(closure2)
with self.assertRaisesRegex(
client.FunctionRetryableError,
'The corresponding function is cancelled. Please reschedule the '
'function.'):
client._RemoteValueStatus.READY)
with self.assertRaisesRegex(ValueError, 'Fake cancellation error.'):
closure2._fetch_output_remote_values()
# At this moment, there is nothing inflight, and the queue is also empty
# (because closure2 should not be added back to the queue).
self.assertTrue(closure_queue._queue.empty())
self.assertEqual(closure_queue._inflight_closure_count, 0)
# This asserts that the queue has a clear state.
self.testBasic()
closure4 = self._create_closure()
def testWaitRaiseErrorAfterCancelClosure(self):
self._test_cancel_closure_when_error(call_wait=True)
e = threading.Event()
def get_fn():
with coord.stop_on_exception():
# This should end up getting closure4, not closure2, because closure2
# has been cancelled and should not be got.
closure_got = closure_queue.get()
e.set()
self.assertEqual(closure_got._error_generation, 1)
self.assertEqual(closure_queue._error_generation, 1)
self.assertIs(closure4, closure_got)
self.assertIsNot(closure2, closure_got)
t = threading.Thread(target=get_fn)
t.start()
time.sleep(10)
# Make sure `closure_got = closure_queue.get()` is unblocked as a result of
# `closure_queue.put(closure4)`.
self.assertFalse(e.is_set())
closure_queue.put(closure4)
self.assertTrue(e.wait())
coord.join([t])
self.assertEqual(closure_queue._inflight_closure_count, 1)
closure_queue.mark_finished(closure4)
# The queue is now cleared and nothing inflight.
self.assertEqual(closure_queue._inflight_closure_count, 0)
closure_queue.wait()
def testWaitRaiseErrorAfterAnErrorIsReported(self):
self._test_error_reporting_and_cancel_flow(call_wait=True)
def testPutRaiseErrorAfterAnErrorIsReported(self):
self._test_error_reporting_and_cancel_flow(call_wait=False)
def testPutRaiseErrorAfterCancelClosure(self):
self._test_cancel_closure_when_error(call_wait=False)
def testStateIsRestoredAfterJoinIsCalled(self):
closure_queue, closure1, closure2 = self._put_two_closures_and_get_one()
closure_queue.get()
self.assertEqual(closure_queue._inflight_closure_count, 2)
closure_queue.mark_failed(ValueError('test error'), closure1)
closure_queue, _, _ = self._put_two_closures_and_get_one()
self.assertEqual(closure_queue._inflight_closure_count, 1)
closure_queue.mark_failed(ValueError('test error'))
with self.assertRaises(ValueError):
closure_queue.put(self._create_closure())
closure_queue.mark_failed(ValueError('test error'), closure2)
# closure2's error is previous generation so should not raise at this
# following put, and _error should have been cleared.
# Its error should have been cleared.
self.assertIsNone(closure_queue._error)
closure_queue.put(self._create_closure())
self.assertIsNone(closure_queue._error)
def testStateIsRestoredAfterJoinIsCalled_WaitShouldReturn(self):
closure_queue, closure1, closure2 = self._put_two_closures_and_get_one()
closure_queue.put(self._create_closure())
closure_queue.get() # got closure2
self.assertFalse(closure_queue._queue.empty()) # still has closure3
self.assertEqual(closure_queue._inflight_closure_count, 2) # closure1,2
closure_queue.mark_failed(ValueError('test error'), closure1)
self.assertTrue(closure_queue._queue.empty()) # closure3 cancelled
self.assertEqual(closure_queue._inflight_closure_count, 1)
with self.assertRaises(ValueError):
closure_queue.wait() # reports error from closure1
# `wait` should block because closure2 is not back yet, even if closure2
# was sent inflight before the error.
self.assertFalse(closure_queue.wait(timeout=20))
self.assertEqual(closure_queue._inflight_closure_count, 1)
closure_queue.mark_finished(closure2)
closure_queue.wait() # wait should pass immediately
self.assertEqual(closure_queue._inflight_closure_count, 0)
def testThreadSafety(self):
thread_count = 10
queue = client._CoordinatedClosureQueue()
queue = client._CoordinatedClosureQueue(MockCancellationManager())
# Each thread performs 20 queue actions: 10 are `put_back` and 10 are
# `mark_finished`.
@@ -372,7 +341,7 @@ class CoordinatedClosureQueueTest(test.TestCase):
if i % 2 == 0:
queue.put_back(closure)
else:
queue.mark_finished(closure)
queue.mark_finished()
threads = [threading.Thread(target=func) for i in range(thread_count)]
for t in threads:

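A note on `_assert_one_unblock_the_other`, used throughout the tests above: it asserts that a blocking call returns only after another thread acts. A toy, self-contained pair that would satisfy it (names illustrative):

import threading

def make_pair():
  """Builds (first_fn, second_fn) where second_fn blocks until first_fn runs."""
  gate = threading.Event()

  def first_fn():
    gate.set()   # the action that unblocks the waiter

  def second_fn():
    gate.wait()  # returns only once first_fn has run

  return first_fn, second_fn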

@@ -19,7 +19,10 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import functools
import threading
from absl import logging
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import multi_worker_test_base
from tensorflow.python.distribute import sharded_variable
@@ -40,6 +43,48 @@ from tensorflow.python.ops import variables
from tensorflow.python.training.server_lib import ClusterSpec
class ErrorReportingThread(threading.Thread):
error = None
def __init__(self, *args, **kwargs):
assert "target" in kwargs
target = kwargs["target"]
@functools.wraps(target)
def wrapped_target(*args, **kwargs):
try:
return target(*args, **kwargs)
except Exception as e: # pylint: disable=broad-except
ErrorReportingThread.error = e
kwargs["target"] = wrapped_target
super(ErrorReportingThread, self).__init__(*args, **kwargs)
class TestCaseWithErrorReportingThread(test.TestCase):
@classmethod
def setUpClass(cls):
cls._threading_thread = threading.Thread
threading.Thread = ErrorReportingThread
super(TestCaseWithErrorReportingThread, cls).setUpClass()
@classmethod
def tearDownClass(cls):
super(TestCaseWithErrorReportingThread, cls).tearDownClass()
threading.Thread = cls._threading_thread
def setUp(self):
ErrorReportingThread.error = None
super(TestCaseWithErrorReportingThread, self).setUp()
def tearDown(self):
super(TestCaseWithErrorReportingThread, self).tearDown()
if ErrorReportingThread.error:
raise ErrorReportingThread.error # pylint: disable=raising-bad-type
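Background for the monkeypatch above: an uncaught exception in a `threading.Thread` target is printed and swallowed by default, so a test could pass even though a worker thread died; `ErrorReportingThread` records the exception and `tearDown` re-raises it. A standalone illustration of the same trick (independent of the TF classes):

import threading

captured = []

def reporting(target):
  def runner(*args, **kwargs):
    try:
      target(*args, **kwargs)
    except Exception as e:  # pylint: disable=broad-except
      captured.append(e)    # surface the failure to the main thread
  return runner

def boom():
  raise ValueError('worker thread failed')

t = threading.Thread(target=reporting(boom))
t.start()
t.join()
if captured:
  raise captured[0]  # the failure now actually fails the run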
def make_client(num_workers, num_ps):
# TODO(rchao): Test the internal rpc_layer version.
cluster_def = multi_worker_test_base.create_in_process_cluster(
@@ -52,7 +97,7 @@ def make_client(num_workers, num_ps):
return parameter_server_client.ParameterServerClient(cluster_resolver)
class ParameterServerClientTest(test.TestCase):
class ParameterServerClientTest(TestCaseWithErrorReportingThread):
@classmethod
def setUpClass(cls):
@@ -304,7 +349,7 @@ class VariablePartitioningScopeTest(test.TestCase):
self.assertEqual(var_sum, 10.0)
class ErrorReportingTest(test.TestCase):
class ErrorReportingTest(TestCaseWithErrorReportingThread):
@classmethod
def setUpClass(cls):
@@ -344,8 +389,16 @@ class ErrorReportingTest(test.TestCase):
while True:
self.client.schedule(self._normal_function)
def testScheduleRaiseErrorWithMultipleFailure(self):
for _ in range(3):
self.client.schedule(self._normal_function)
self.client.schedule(self._error_function)
with self.assertRaises(errors.InvalidArgumentError):
while True:
self.client.schedule(self._error_function)
self.client.join()
def testErrorWillbeCleared(self):
self.skipTest("b/157597579")
self.client.schedule(self._error_function)
with self.assertRaises(errors.InvalidArgumentError):
self.client.join()
@@ -356,7 +409,7 @@ class ErrorReportingTest(test.TestCase):
with self.assertRaises(errors.InvalidArgumentError):
self.client.join()
def testFutureReturnError(self):
def testRemoteValueReturnError(self):
result = self.client.schedule(self._error_function)
with self.assertRaises(errors.InvalidArgumentError):