Cancel in-flight closures when there is an error.
PiperOrigin-RevId: 324542620
Change-Id: I1d6cddf8130df74f00ce7b0a3b6b84f553990e78
parent 37a0da627c
commit 151bd5901a
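This change drops the error-generation bookkeeping in the closure queue in favor of a single `cancellation.CancellationManager` owned by the cluster: every concrete function the client dispatches is wrapped with `get_cancelable_function`, and when a non-retryable error surfaces the queue calls `start_cancel()` so all in-flight closures abort instead of being tracked generation by generation. A minimal standalone sketch of that pattern (not part of the diff below; `step` is an invented stand-in for a scheduled closure):

from tensorflow.python.eager import cancellation
from tensorflow.python.eager import def_function
from tensorflow.python.framework import errors

@def_function.function
def step():  # invented stand-in for a closure the client would schedule
  return 1.0

mgr = cancellation.CancellationManager()
# Wrap the traced ConcreteFunction so a later start_cancel() can abort it.
cancelable_step = mgr.get_cancelable_function(step.get_concrete_function())
cancelable_step()   # runs normally before cancellation
mgr.start_cancel()  # one call aborts every function wrapped by this manager
try:
  cancelable_step()  # now raises instead of executing (or unblocks, if it
                     # were waiting on a remote worker)
except errors.CancelledError:
  pass  # the client surfaces this as "please reschedule the function"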
@@ -32,6 +32,7 @@ py_library(
         "//tensorflow/python/distribute:input_lib",
         "//tensorflow/python/distribute:parameter_server_strategy_v2",
         "//tensorflow/python/distribute:values",
+        "//tensorflow/python/eager:cancellation",
         "//tensorflow/python/eager:context",
         "//tensorflow/python/eager:def_function",
         "//tensorflow/python/eager:executor",
@@ -31,15 +31,19 @@ import threading
 import weakref
 from absl import logging
 from six.moves import queue

+from tensorflow.python.distribute import distribute_lib
 from tensorflow.python.distribute import input_lib
 from tensorflow.python.distribute import parameter_server_strategy_v2
 from tensorflow.python.distribute.client import metric_utils
+from tensorflow.python.eager import cancellation
 from tensorflow.python.eager import context
 from tensorflow.python.eager import def_function
 from tensorflow.python.eager import executor
 from tensorflow.python.eager import function as tf_function
 from tensorflow.python.eager import remote
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors
 from tensorflow.python.framework import func_graph
 from tensorflow.python.framework import ops
@@ -247,20 +251,28 @@ class PerWorkerValues(object):
     self._values = tuple(values)


+def _select_worker_slice(worker_id, structured):
+  """Selects the worker slice of each of the items in `structured`."""
+
+  def _get(x):
+    return x._values[worker_id] if isinstance(x, PerWorkerValues) else x  # pylint: disable=protected-access
+
+  return nest.map_structure(_get, structured)
+
+
 class Closure(object):
   """Hold a function to be scheduled and its arguments."""

-  def __init__(self, function, args=None, kwargs=None):
+  def __init__(self, function, cancellation_mgr, args=None, kwargs=None):
     if not callable(function):
       raise ValueError("Function passed to `Client.schedule` must be a "
                        "callable object.")
     self._args = args or ()
     self._kwargs = kwargs or {}
-    self._function = function

     if isinstance(function, def_function.Function):
-      replica_args = self._select_worker_slice(0, self._args)
-      replica_kwargs = self._select_worker_slice(0, self._kwargs)
+      replica_args = _select_worker_slice(0, self._args)
+      replica_kwargs = _select_worker_slice(0, self._kwargs)

       # Note: no need to handle function registration failure since this kind of
       # failure will not raise exceptions as designed in the runtime. The client
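For orientation, the helper that just moved to module scope picks one worker's component out of each `PerWorkerValues` in a nested structure and passes everything else through. A tiny illustration with invented values:

pwv = PerWorkerValues([iterator_for_worker0, iterator_for_worker1])  # hypothetical
_select_worker_slice(1, {"iterator": pwv, "steps": 5})
# -> {"iterator": iterator_for_worker1, "steps": 5}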
@@ -276,25 +288,22 @@ class Closure(object):
       concrete_function = function.get_concrete_function(
           *nest.map_structure(_maybe_as_type_spec, replica_args),
           **nest.map_structure(_maybe_as_type_spec, replica_kwargs))
+      self._function = cancellation_mgr.get_cancelable_function(
+          concrete_function)
       self._output_remote_values = nest.map_structure(
           lambda x: RemoteValue(self, x), concrete_function.structured_outputs)
     elif isinstance(function, tf_function.ConcreteFunction):
+      self._function = cancellation_mgr.get_cancelable_function(
+          concrete_function)
       self._output_remote_values = nest.map_structure(
           lambda x: RemoteValue(self, x), function.structured_outputs)
     else:
       # Regular python functions.
+      self._function = function
       # TODO(yuefengz): maybe we should trace python functions if their inputs
       # are Python primitives, tensors and composite tensors.
       self._output_remote_values = RemoteValue(self, None)

-  def _select_worker_slice(self, worker_id, structured):
-    """Selects the worker slice of each of the items in `structured`."""
-
-    def _get(x):
-      return x._values[worker_id] if isinstance(x, PerWorkerValues) else x  # pylint: disable=protected-access
-
-    return nest.map_structure(_get, structured)
-
   def _fetch_output_remote_values(self):
     """Temporary method used to sync the scheduler."""
     # It will do nothing if there is no return value.
@@ -319,9 +328,8 @@ class Closure(object):
     Args:
       worker: a `Worker` object.
     """
-    replica_args = self._select_worker_slice(worker.worker_index, self._args)
-    replica_kwargs = self._select_worker_slice(worker.worker_index,
-                                               self._kwargs)
+    replica_args = _select_worker_slice(worker.worker_index, self._args)
+    replica_kwargs = _select_worker_slice(worker.worker_index, self._kwargs)

     e = (
         _maybe_get_error_and_rebuild_remote_values(worker, replica_args) or
@@ -350,8 +358,7 @@ class _CoordinatedClosureQueue(object):
   This class is thread-safe.
   """

-  def __init__(self):
-
+  def __init__(self, cancellation_mgr):
     # `self._inflight_closure_count` only tracks the number of inflight closures
     # that are "in generation". Once an error occurs, error generation is
     # incremented and all subsequent arriving closures (from inflight) are
@@ -359,17 +366,26 @@ class _CoordinatedClosureQueue(object):
     self._inflight_closure_count = 0

     self._queue_lock = threading.Lock()
+
     # Condition indicating that all pending closures (either queued or inflight)
     # have been processed, failed, or cancelled.
     self._stop_waiting_condition = threading.Condition(self._queue_lock)
+
     # Condition indicating that an item becomes available in queue (not empty).
     self._closures_queued_condition = threading.Condition(self._queue_lock)
+
     # Condition indicating that a queue slot becomes available (not full).
     # Note that even with "infinite" queue size, there is still a "practical"
     # size limit for the queue depending on host memory capacity, and thus the
     # queue will eventually become full with a lot of enqueued closures.
     self._queue_free_slot_condition = threading.Condition(self._queue_lock)

+    # Condition indicating there is no inflight closures.
+    self._no_inflight_closure_condition = threading.Condition(self._queue_lock)
+
+    # Use to cancel in-flight closures.
+    self._cancellation_mgr = cancellation_mgr
+
     if _CLOSURE_QUEUE_MAX_SIZE <= 0:
       logging.warning(
           "In ParameterServerClient, creating an infinite closure queue can "
@@ -377,31 +393,6 @@ class _CoordinatedClosureQueue(object):
     self._queue = queue.Queue(maxsize=_CLOSURE_QUEUE_MAX_SIZE)
     self._error = None

-    # Error generation is a counter that helps us track whether a closure
-    # should be cancelled when it is being put back to `self._queue`. It works
-    # in the following way:
-    # 1) Error generation starts off at 0.
-    # 2) When a worker thread calls `get()`, the closure's error generation
-    #    is copied from this queue's error generation.
-    # 3) If any worker thread experiences an error that's categorized as a
-    #    non-retryable error, the queue's error will be set, error generation
-    #    increments by 1, and the queue is cleared (with the closures marked
-    #    with cancelled error), so other worker threads stop getting closures
-    #    from the queue. Worker preemption is categorized as a retryable error.
-    # 4) At this point, if `put()` or `wait()` is called (usually by the main
-    #    thread via `schedule` and `join`), the error is raised through that
-    #    call.
-    # 5) The closures that are inflight, i.e. that are being executed remotely,
-    #    will not be aware of such error event. If the worker that's executing
-    #    the closure happens to be interrupted, the closure should not be put
-    #    back to the queue, and be cancelled with error instead. Checking the
-    #    generation id of the closure and queue is how the worker thread tells
-    #    whether the closure should be put back. Likewise for `mark_finished`
-    #    and `mark_failed`: if the arriving closure is considered out of
-    #    generation in those two methods, it is simply discarded (the inflight
-    #    closure count still decrements).
-    self._error_generation = 0
-
     # The following is a lock to make sure when `wait` is called and before it
     # returns no `put` can be executed during this period. It is because `wait`
     # won't know what to do with newly put closures. This lock adds an cutoff
@@ -415,11 +406,14 @@ class _CoordinatedClosureQueue(object):
     # of the code.
     self._put_wait_lock = threading.Lock()

-  def _cancel_closures_in_queue(self):
+  def _cancel_all_closures(self):
     """Clears the queue and sets remaining closures cancelled error.

     This method expects self._queue_lock to be held prior to entry.
     """
+    self._cancellation_mgr.start_cancel()
+    while self._inflight_closure_count > 0:
+      self._no_inflight_closure_condition.wait()
     while True:
       try:
         closure = self._queue.get(block=False)
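The ordering in `_cancel_all_closures` is the crux of the change: `start_cancel()` fires first so closures blocked on remote execution return, which is what lets the inflight-count wait terminate; only then is the queue flushed. A toy model of the same drain-then-flush sequence, with all names invented:

import queue
import threading

class MiniClosureQueue(object):
  """Toy model of the shutdown order above; not TensorFlow code."""

  def __init__(self, cancellation_mgr):
    self._mgr = cancellation_mgr  # anything exposing start_cancel()
    self._lock = threading.Lock()
    self._no_inflight = threading.Condition(self._lock)
    self._inflight = 0
    self._queue = queue.Queue()

  def _cancel_all(self):
    # The caller is assumed to hold self._lock, as in the method above.
    self._mgr.start_cancel()    # unblocks in-flight closures, so that...
    while self._inflight > 0:
      self._no_inflight.wait()  # ...this drain is guaranteed to finish
    while True:                 # then flush everything that never started
      try:
        self._queue.get(block=False)
      except queue.Empty:
        break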
@@ -437,8 +431,8 @@ class _CoordinatedClosureQueue(object):
     This method expects self._queue_lock to be held prior to entry.
     """
     if self._error:
+      self._cancel_all_closures()
       try:
-        self._cancel_closures_in_queue()
         raise self._error  # pylint: disable=raising-bad-type
       finally:
         self._error = None
@@ -466,16 +460,17 @@ class _CoordinatedClosureQueue(object):
         return None
       closure = self._queue.get(block=False)
       self._queue_free_slot_condition.notify()
-      closure._error_generation = self._error_generation  # pylint: disable=protected-access
       self._inflight_closure_count += 1
       return closure

-  def mark_finished(self, closure):
+  def mark_finished(self):
     """Let the queue know that a closure has been successfully executed."""
     with self._queue_lock:
       if self._inflight_closure_count < 1:
         raise AssertionError("There is no inflight closures to mark_finished.")
       self._inflight_closure_count -= 1
+      if self._inflight_closure_count == 0:
+        self._no_inflight_closure_condition.notifyAll()
       if self._queue.empty() and self._inflight_closure_count == 0:
         self._stop_waiting_condition.notifyAll()

@@ -484,17 +479,15 @@ class _CoordinatedClosureQueue(object):
     with self._queue_lock:
       if self._inflight_closure_count < 1:
         raise AssertionError("There is no inflight closures to put_back.")
-      self._inflight_closure_count -= 1
-      if closure._error_generation < self._error_generation:  # pylint: disable=protected-access
-        # If the closure to put back is out of generation, cancel the closure
-        # and ignore it.
-        logging.info("Function %r should no longer be dispatched; marking "
-                     "as cancelled.")
+      if self._error:
         closure._set_output_remote_values_cancelled()  # pylint: disable=protected-access
-        return
-      self._queue_free_slot_condition.wait_for(lambda: not self._queue.full())
-      self._queue.put(closure, block=False)
-      self._closures_queued_condition.notify()
+      else:
+        self._queue_free_slot_condition.wait_for(lambda: not self._queue.full())
+        self._queue.put(closure, block=False)
+        self._closures_queued_condition.notify()
+      self._inflight_closure_count -= 1
+      if self._inflight_closure_count == 0:
+        self._no_inflight_closure_condition.notifyAll()

   def wait(self, timeout=None):
     """Wait for all closures to be finished before returning.
@@ -516,22 +509,18 @@ class _CoordinatedClosureQueue(object):
       self._raise_if_error()
       return True

-  def mark_failed(self, e, closure):
+  def mark_failed(self, e):
     """Sets error and unblocks any wait() call."""
     with self._queue_lock:
       # TODO(yuefengz): maybe record all failure and give users more
       # information?
       if self._inflight_closure_count < 1:
         raise AssertionError("There is no inflight closures to mark_failed.")
+      if self._error is None:
+        self._error = e
       self._inflight_closure_count -= 1
-      if closure._error_generation < self._error_generation:  # pylint: disable=protected-access
-        # If the closure to mark fail is out of generation, simply ignore it
-        # (with the actual error associated with the closure preserved).
-        return
-      assert self._error is None
-      self._error = e
-      self._error_generation += 1
-      self._cancel_closures_in_queue()
+      if self._inflight_closure_count == 0:
+        self._no_inflight_closure_condition.notifyAll()
       self._stop_waiting_condition.notifyAll()

   def done(self):
@@ -678,7 +667,7 @@ class Worker(object):
         # TODO(yuefengz): we don't have to materialize results every step.
         with metric_utils.monitored_timer("remote_value_fetch"):
           closure._fetch_output_remote_values()  # pylint: disable=protected-access
-        self._cluster._closure_queue.mark_finished(closure)  # pylint: disable=protected-access
+        self._cluster._closure_queue.mark_finished()  # pylint: disable=protected-access
       except Exception as e:  # pylint: disable=broad-except
         logging.error(
             "/job:worker/task:%d encountered the following error when processing "
@@ -686,7 +675,7 @@ class Worker(object):
         nest.map_structure(
             lambda x: x._set_error(e),  # pylint: disable=protected-access
             closure._output_remote_values)  # pylint: disable=protected-access
-        self._cluster._closure_queue.mark_failed(e, closure)  # pylint: disable=protected-access
+        self._cluster._closure_queue.mark_failed(e)  # pylint: disable=protected-access

   def _process_queue(self):
     while True:
@@ -710,7 +699,8 @@ class Worker(object):
     # the same worker such as creating resources, setting resources' aborted
     # status, and executing closures happen on the same thread. This allows us
     # to have simpler logic of concurrency.
-    closure = Closure(function=function, args=args, kwargs=kwargs)
+    closure = Closure(
+        function, self._cluster._cancellation_mgr, args=args, kwargs=kwargs)  # pylint: disable=protected-access
     resource_remote_value = closure._output_remote_values  # pylint: disable=protected-access
     self._register_resource(resource_remote_value)

@@ -775,7 +765,8 @@ class Cluster(object):
         protocol=cluster_resolver.rpc_layer,
         cluster_device_filters=device_filters)

-    self._closure_queue = _CoordinatedClosureQueue()
+    self._cancellation_mgr = cancellation.CancellationManager()
+    self._closure_queue = _CoordinatedClosureQueue(self._cancellation_mgr)
     self.failure_handler = WorkerPreemptionHandler(context.get_server_def())
     worker_device_strings = [
         "/job:worker/replica:0/task:%d" % i for i in range(self._num_workers)
@@ -796,7 +787,8 @@ class Cluster(object):
     Returns:
       A structure of `RemoteValue` object.
     """
-    closure = Closure(function=function, args=args, kwargs=kwargs)
+    closure = Closure(
+        function, self._cancellation_mgr, args=args, kwargs=kwargs)
     self._closure_queue.put(closure)
     return closure._output_remote_values  # pylint: disable=protected-access

@@ -893,8 +885,8 @@ class Client(object):
     function execution to finish and retrieve its output from the remote worker.

     `schedule` guarantees that `fn` will be executed on a worker at least once;
-    it could be more than once if a worker fails and restarts in the middle of
-    function scheduling. Note that since worker can fail at any point when
+    it could be more than once if its corresponding worker fails in the middle
+    of its execution. Note that since worker can fail at any point when
     executing the function, it is possible that the function is partially
     executed, but `Client` guarantees that in those events, the function will
     eventually be fully executed, possibly on a different worker that is
@@ -904,14 +896,12 @@ class Client(object):
     by raising any one of those errors, and clear the errors collected so far.
     There are two implications when this happens: 1) user should call `schedule`
     with `fn` again to re-schedule, and 2) some of the previously scheduled
-    functions may no longer execute. User can call `fetch` on the returned
+    functions may have not been executed. User can call `fetch` on the returned
     `RemoteValue` to inspect if they have executed, failed, or cancelled, and
     reschedule the corresponding function if needed.

-    When `schedule` raises, it is possible that there are still functions being
-    executed on workers, at the time `schedule` raises. When this happens, users
-    can call `join` again to wait for all pending async function execution to
-    finish, and bring the cluster into a consistent state.
+    When `schedule` raises, it guarantees that there is no function that is
+    still being executed.

     At this time, there is no support of worker assignment for function
     execution, or priority of the workers.
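Hypothetical usage matching the contract described above (`cluster_resolver` and `train_step` are assumed to be defined elsewhere; none of this is in the diff):

client = parameter_server_client.ParameterServerClient(cluster_resolver)
result = client.schedule(train_step)  # returns a RemoteValue immediately
try:
  client.join()  # raises one of the collected errors, if any occurred
except Exception:  # pylint: disable=broad-except
  # By the new contract nothing is still running once this raises; fetch()
  # re-raises for a failed or cancelled closure, so reschedule as needed.
  result.fetch()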
@@ -940,7 +930,8 @@ class Client(object):
     # TODO(b/160702436): Invoke `strategy.run` for user's function so it enters
     # a `ReplicaContext` in a logically correct way.
     with distribute_lib.ReplicaContext(
-        self._strategy, replica_id_in_sync_group=0):
+        self._strategy,
+        replica_id_in_sync_group=constant_op.constant(0, dtypes.int32)):
       with self._translate_parameter_server_failure():
         return self.cluster.schedule(fn, args=args, kwargs=kwargs)

@@ -949,17 +940,14 @@ class Client(object):

     If any previously scheduled function raises an error, `join` will fail by
     raising any one of those errors, and clear the errors collected so far. If
-    this happens, some of the previously scheduled functions may no longer
-    execute. Users can call `fetch` on the returned `RemoteValue` to inspect if
+    this happens, some of the previously scheduled functions may have not been
+    executed. Users can call `fetch` on the returned `RemoteValue` to inspect if
     they have executed, failed, or cancelled. If some that have been cancelled
     need to be rescheduled, users should call `schedule` with the function
     again.

-    Note: `join` raises an exception as soon as the client detects one, and this
-    means it is possible that there are still functions being executed on
-    workers, at the time `join` raises. When this happens, users can call `join`
-    again to wait for all pending async function execution to finish, and bring
-    the cluster into a consistent state.
+    When `join` returns or raises, it guarantees that there is no function that
+    is still being executed.

     Raises:
       Exception: one of the exceptions caught by the client by any previously
@@ -976,6 +964,9 @@ class Client(object):

     If any previously scheduled function raises an error, `done` will fail by
     raising any one of those errors.
+
+    When `done` returns True or raises, it guarantees that there is no function
+    that is still being executed.
     """
     return self.cluster.done()

@@ -1091,7 +1082,7 @@ class Client(object):
       raise


-class _PerWorkerDistributedDataset(object):  # pylint: disable=protected-access
+class _PerWorkerDistributedDataset(object):
   """Represents worker-distributed datasets created from dataset function."""

   def __init__(self, dataset_fn, input_workers, client):
@@ -1107,13 +1098,13 @@ class _PerWorkerDistributedDataset(object):  # pylint: disable=protected-access

     if isinstance(dataset_fn, def_function.Function):
       with variable_scope.variable_creator_scope(disallow_variable_creation):
-        self._dataset_fn = dataset_fn.get_concrete_function()
-    elif isinstance(dataset_fn, tf_function.ConcreteFunction):
-      self._dataset_fn = dataset_fn
-    else:
+        dataset_fn = dataset_fn.get_concrete_function()
+    elif not isinstance(dataset_fn, tf_function.ConcreteFunction):
       with variable_scope.variable_creator_scope(disallow_variable_creation):
-        self._dataset_fn = def_function.function(
-            dataset_fn).get_concrete_function()
+        dataset_fn = def_function.function(dataset_fn).get_concrete_function()
+    self._dataset_fn = (
+        client.cluster._cancellation_mgr.get_cancelable_function(  # pylint: disable=protected-access
+            dataset_fn))
     self._input_workers = input_workers
     self._client = client
     self._element_spec = None
@@ -30,22 +30,34 @@ from tensorflow.python.training import coordinator
 from tensorflow.python.util import nest


+class MockCancellationManager(object):
+
+  def __init__(self):
+    self.cancelled = False
+
+  def start_cancel(self):
+    self.cancelled = True
+
+  def get_cancelable_function(self, func):
+    return func
+
+
 class CoordinatedClosureQueueTest(test.TestCase):

   def testBasic(self):
-    queue = client._CoordinatedClosureQueue()
+    queue = client._CoordinatedClosureQueue(MockCancellationManager())
     closure1 = self._create_closure()
     queue.put(closure1)
     self.assertIs(closure1, queue.get())
     self.assertFalse(queue.done())
     queue.put_back(closure1)
     self.assertEqual(closure1, queue.get())
-    queue.mark_finished(closure1)
+    queue.mark_finished()
     self.assertTrue(queue.done())
     queue.wait()

   def testProcessAtLeaseOnce(self):
-    closure_queue = client._CoordinatedClosureQueue()
+    closure_queue = client._CoordinatedClosureQueue(MockCancellationManager())
     labels = ['A', 'B', 'C', 'D', 'E']
     processed_count = collections.defaultdict(int)

@@ -63,7 +75,7 @@ class CoordinatedClosureQueueTest(test.TestCase):
           closure_queue.put_back(closure)
           continue
         closure._function()
-        closure_queue.mark_finished(closure)
+        closure_queue.mark_finished()

     def get_func(label):

@@ -76,7 +88,8 @@ class CoordinatedClosureQueueTest(test.TestCase):
       return func

     for label in labels:
-      closure_queue.put(client.Closure(get_func(label)))
+      closure_queue.put(
+          client.Closure(get_func(label), MockCancellationManager()))
     t1 = threading.Thread(target=process_queue, daemon=True)
     t1.start()
     t2 = threading.Thread(target=process_queue, daemon=True)
@@ -93,7 +106,7 @@ class CoordinatedClosureQueueTest(test.TestCase):
     coord.join([t1, t2])

   def testNotifyBeforeWait(self):
-    closure_queue = client._CoordinatedClosureQueue()
+    closure_queue = client._CoordinatedClosureQueue(MockCancellationManager())

     def func():
       logging.info('func running')
@@ -102,10 +115,10 @@ class CoordinatedClosureQueueTest(test.TestCase):

     def process_queue():
       with coord.stop_on_exception():
-        closure = closure_queue.get()
-        closure_queue.mark_finished(closure)
+        closure_queue.get()
+        closure_queue.mark_finished()

-    closure_queue.put(client.Closure(func))
+    closure_queue.put(client.Closure(func, MockCancellationManager()))
     t = threading.Thread(target=process_queue)
     t.start()
     coord.join([t])
@@ -114,8 +127,30 @@ class CoordinatedClosureQueueTest(test.TestCase):
     # doesn't time out.
     closure_queue.wait()

+  def _assert_one_unblock_the_other(self, first_fn, second_fn):
+    """Asserts `second_fn` wouldn't return before `first_fn` is finished."""
+    first_fn_done = threading.Event()
+    second_fn_done = threading.Event()
+    coord = coordinator.Coordinator(clean_stop_exception_types=[])
+
+    def wrapped_first_fn():
+      with coord.stop_on_exception():
+        self.assertFalse(second_fn_done.is_set())
+        first_fn()
+        first_fn_done.set()
+
+    self.assertFalse(first_fn_done.is_set())
+    t = threading.Thread(target=wrapped_first_fn)
+    t.start()
+
+    second_fn()
+    self.assertTrue(first_fn_done.is_set())
+    second_fn_done.set()
+
+    coord.join([t])
+
   def testWaitRaiseErrorAfterMarkFailure(self):
-    closure_queue = client._CoordinatedClosureQueue()
+    closure_queue = client._CoordinatedClosureQueue(MockCancellationManager())
     closure_queue.put(self._create_closure())
     closure = closure_queue.get()

@@ -126,22 +161,17 @@ class CoordinatedClosureQueueTest(test.TestCase):
     # all inflight closures are finished.

     def mark_finished_fn():
-      with coord.stop_on_exception():
-        self.assertFalse(wait_finish_event.is_set())
-        try:
-          raise ValueError('Some error.')
-        except ValueError as e:
-          closure_queue.mark_failed(e, closure)
-          wait_finish_event.wait()
+      try:
+        raise ValueError('Some error.')
+      except ValueError as e:
+        closure_queue.mark_failed(e)

-    t = threading.Thread(target=mark_finished_fn)
-    t.start()
+    def wait_fn():
+      with self.assertRaises(ValueError):
+        closure_queue.wait()

-    with self.assertRaises(ValueError):
-      closure_queue.wait()
-    wait_finish_event.set()
+    self._assert_one_unblock_the_other(mark_finished_fn, wait_fn)

-    coord.join([t])
     self.assertTrue(closure_queue.done())

   def _create_closure(self):
@@ -150,10 +180,10 @@ class CoordinatedClosureQueueTest(test.TestCase):
     def some_function():
       return 1.0

-    return client.Closure(some_function)
+    return client.Closure(some_function, MockCancellationManager())

   def _put_two_closures_and_get_one(self):
-    closure_queue = client._CoordinatedClosureQueue()
+    closure_queue = client._CoordinatedClosureQueue(MockCancellationManager())
     closure1 = self._create_closure()
     closure_queue.put(closure1)

@@ -166,9 +196,9 @@ class CoordinatedClosureQueueTest(test.TestCase):
     return closure_queue, closure1, closure2

   def testPutRaiseError(self):
-    closure_queue, closure1, closure2 = self._put_two_closures_and_get_one()
+    closure_queue, _, closure2 = self._put_two_closures_and_get_one()

-    closure_queue.mark_failed(ValueError(), closure1)
+    closure_queue.mark_failed(ValueError())

     with self.assertRaises(ValueError):
       closure_queue.put(self._create_closure())
@@ -185,9 +215,9 @@ class CoordinatedClosureQueueTest(test.TestCase):
       closure_queue.put(self._create_closure())

   def testWaitRaiseError(self):
-    closure_queue, closure1, closure2 = self._put_two_closures_and_get_one()
+    closure_queue, _, closure2 = self._put_two_closures_and_get_one()

-    closure_queue.mark_failed(ValueError(), closure1)
+    closure_queue.mark_failed(ValueError())

     with self.assertRaises(ValueError):
       closure_queue.wait()
@@ -203,15 +233,22 @@ class CoordinatedClosureQueueTest(test.TestCase):
       closure_queue.wait()

   def testDoneRaiseError(self):
-    closure_queue, closure1, _ = self._put_two_closures_and_get_one()
-    closure_queue.get()
+    closure_queue, _, _ = self._put_two_closures_and_get_one()

     self.assertFalse(closure_queue.done())
-    closure_queue.mark_failed(ValueError(), closure1)
+    closure_queue.mark_failed(ValueError())
     with self.assertRaises(ValueError):
       closure_queue.done()

-  def _test_error_reporting_and_cancel_flow(self, call_wait):
+  def _set_error(self, closure_queue, closure, error):
+    try:
+      raise error
+    except Exception as e:  # pylint: disable=broad-except
+      nest.map_structure(lambda x: x._set_error(e),
+                         closure._output_remote_values)
+      closure_queue.mark_failed(e)
+
+  def _test_cancel_closure_when_error(self, call_wait):
     closure_queue, closure1, closure2 = self._put_two_closures_and_get_one()
     closure_queue.put(self._create_closure())
     closure_queue.get()
@@ -219,34 +256,37 @@
     self.assertEqual(closure_queue._inflight_closure_count, 2)

     # Simulating closure1 fails.
-    try:
-      raise ValueError('Some error.')
-    except ValueError as e:
-      nest.map_structure(lambda x: x._set_error(e),
-                         closure1._output_remote_values)
-      self.assertEqual(closure_queue._error_generation, 0)  # pylint: disable=g-assert-in-except
-      closure_queue.mark_failed(e, closure1)
-      self.assertEqual(closure_queue._error_generation, 1)
-    # At this moment, there are one inflight, nothing
-    # in queue (because the ones in queue should have been removed and
-    # cancelled).
-    self.assertTrue(closure_queue._queue.empty())
-    # Doesn't include out of generation closures.
+    self._set_error(closure_queue, closure1, ValueError('Some error.'))

+    # At this moment, there are one inflight, one in queue.
+    self.assertEqual(closure_queue._queue.qsize(), 1)
     self.assertEqual(closure_queue._inflight_closure_count, 1)

-    coord = coordinator.Coordinator(clean_stop_exception_types=[])
     closure3 = self._create_closure()

-    with self.assertRaises(ValueError):
-      # Verifying `wait()` or `put()` raises even if one closure is in
-      # flight.
-      if call_wait:
-        closure_queue.wait()
-      else:
-        closure_queue.put(closure3)
-    # At this moment, there is one inflight, nothing in queue.
+    def fake_cancellation():
+      self._set_error(closure_queue, closure2,
+                      ValueError('Fake cancellation error.'))
+
+    def report_error():
+      # It should not report the fake cancellation error.
+      with self.assertRaisesRegex(ValueError, 'Some error.'):
+        # Verifying `wait()` or `put()` raises even if one closure is in
+        # flight.
+        if call_wait:
+          closure_queue.wait()
+        else:
+          closure_queue.put(closure3)
+
+    self._assert_one_unblock_the_other(fake_cancellation, report_error)
+
+    # Cancellation manager has been called.
+    self.assertTrue(closure_queue._cancellation_mgr.cancelled)
+
+    # At this moment, there is zero inflight, nothing in queue.
     self.assertTrue(closure_queue._queue.empty())
-    self.assertEqual(closure_queue._inflight_closure_count, 1)
+    self.assertEqual(closure_queue._inflight_closure_count, 0)
     self.assertIsNone(closure_queue._error)

     # This asserts that closure1 has errored.
     with self.assertRaisesRegex(ValueError, 'Some error.'):
@@ -260,107 +300,36 @@ class CoordinatedClosureQueueTest(test.TestCase):
         'function.'):
       closure3._fetch_output_remote_values()

-    # Closure2 is inflight, so it shouldn't be ready.
+    # Closure2 was an inflight closure when it got cancelled.
     self.assertEqual(closure2._output_remote_values._status,
-                     client._RemoteValueStatus.NOT_READY)
-
-    # And `wait` should block because closure2 is not back yet.
-    self.assertFalse(closure_queue.wait(timeout=20))
-
-    # Now let's assume that closure2 isn't successful due to worker preemption,
-    # and now it's attempted to be put back, but ends up getting cancelled.
-    self.assertEqual(closure2._error_generation, 0)
-    self.assertEqual(closure_queue._error_generation, 1)
-    closure_queue.put_back(closure2)
-
-    with self.assertRaisesRegex(
-        client.FunctionRetryableError,
-        'The corresponding function is cancelled. Please reschedule the '
-        'function.'):
+                     client._RemoteValueStatus.READY)
+    with self.assertRaisesRegex(ValueError, 'Fake cancellation error.'):
       closure2._fetch_output_remote_values()

-    # At this moment, there is nothing inflight, and the queue is also empty
-    # (because closure2 should not be added back to the queue).
-    self.assertTrue(closure_queue._queue.empty())
-    self.assertEqual(closure_queue._inflight_closure_count, 0)
+    # This asserts that the queue has a clear state.
+    self.testBasic()

-    closure4 = self._create_closure()
+  def testWaitRaiseErrorAfterCancelClosure(self):
+    self._test_cancel_closure_when_error(call_wait=True)

-    e = threading.Event()
-
-    def get_fn():
-      with coord.stop_on_exception():
-        # This should end up getting closure4, not closure2, because closure2
-        # has been cancelled and should not be got.
-        closure_got = closure_queue.get()
-        e.set()
-        self.assertEqual(closure_got._error_generation, 1)
-        self.assertEqual(closure_queue._error_generation, 1)
-        self.assertIs(closure4, closure_got)
-        self.assertIsNot(closure2, closure_got)
-
-    t = threading.Thread(target=get_fn)
-    t.start()
-
-    time.sleep(10)
-
-    # Make sure `closure_got = closure_queue.get()` is unblocked as a result of
-    # `closure_queue.put(closure4)`.
-    self.assertFalse(e.is_set())
-    closure_queue.put(closure4)
-    self.assertTrue(e.wait())
-    coord.join([t])
-
-    self.assertEqual(closure_queue._inflight_closure_count, 1)
-    closure_queue.mark_finished(closure4)
-    # The queue is now cleared and nothing inflight.
-    self.assertEqual(closure_queue._inflight_closure_count, 0)
-    closure_queue.wait()
-
-  def testWaitRaiseErrorAfterAnErrorIsReported(self):
-    self._test_error_reporting_and_cancel_flow(call_wait=True)
-
-  def testPutRaiseErrorAfterAnErrorIsReported(self):
-    self._test_error_reporting_and_cancel_flow(call_wait=False)
+  def testPutRaiseErrorAfterCancelClosure(self):
+    self._test_cancel_closure_when_error(call_wait=False)

   def testStateIsRestoredAfterJoinIsCalled(self):
-    closure_queue, closure1, closure2 = self._put_two_closures_and_get_one()
-    closure_queue.get()
-    self.assertEqual(closure_queue._inflight_closure_count, 2)
-    closure_queue.mark_failed(ValueError('test error'), closure1)
+    closure_queue, _, _ = self._put_two_closures_and_get_one()
+    self.assertEqual(closure_queue._inflight_closure_count, 1)
+    closure_queue.mark_failed(ValueError('test error'))
     with self.assertRaises(ValueError):
       closure_queue.put(self._create_closure())
-    closure_queue.mark_failed(ValueError('test error'), closure2)

-    # closure2's error is previous generation so should not raise at this
-    # following put, and _error should have been cleared.
+    # Its error should have been cleared.
     self.assertIsNone(closure_queue._error)
     closure_queue.put(self._create_closure())
     self.assertIsNone(closure_queue._error)

-  def testStateIsRestoredAfterJoinIsCalled_WaitShouldReturn(self):
-    closure_queue, closure1, closure2 = self._put_two_closures_and_get_one()
-    closure_queue.put(self._create_closure())
-    closure_queue.get()  # got closure2
-    self.assertFalse(closure_queue._queue.empty())  # still has closure3
-    self.assertEqual(closure_queue._inflight_closure_count, 2)  # closure1,2
-    closure_queue.mark_failed(ValueError('test error'), closure1)
-    self.assertTrue(closure_queue._queue.empty())  # closure3 cancelled
-    self.assertEqual(closure_queue._inflight_closure_count, 1)
-    with self.assertRaises(ValueError):
-      closure_queue.wait()  # reports error from closure1
-
-    # `wait` should block because closure2 is not back yet, even if closure2
-    # was sent inflight before the error.
-    self.assertFalse(closure_queue.wait(timeout=20))
-    self.assertEqual(closure_queue._inflight_closure_count, 1)
-    closure_queue.mark_finished(closure2)
-    closure_queue.wait()  # wait should pass immediately
-    self.assertEqual(closure_queue._inflight_closure_count, 0)
-
   def testThreadSafey(self):
     thread_count = 10
-    queue = client._CoordinatedClosureQueue()
+    queue = client._CoordinatedClosureQueue(MockCancellationManager())

     # Each thread performs 20 queue actions: 10 are `put_back` and 10 are
     # `mark_finished`.
@@ -372,7 +341,7 @@ class CoordinatedClosureQueueTest(test.TestCase):
       if i % 2 == 0:
         queue.put_back(closure)
       else:
-        queue.mark_finished(closure)
+        queue.mark_finished()

     threads = [threading.Thread(target=func) for i in range(thread_count)]
     for t in threads:

@@ -19,7 +19,10 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

+import functools
+import threading
+
 from absl import logging

 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.distribute import multi_worker_test_base
 from tensorflow.python.distribute import sharded_variable
@@ -40,6 +43,48 @@ from tensorflow.python.ops import variables
 from tensorflow.python.training.server_lib import ClusterSpec


+class ErrorReportingThread(threading.Thread):
+
+  error = None
+
+  def __init__(self, *args, **kwargs):
+    assert "target" in kwargs
+    target = kwargs["target"]
+
+    @functools.wraps(target)
+    def wrapped_target(*args, **kwargs):
+      try:
+        return target(*args, **kwargs)
+      except Exception as e:  # pylint: disable=broad-except
+        ErrorReportingThread.error = e
+
+    kwargs["target"] = wrapped_target
+    super(ErrorReportingThread, self).__init__(*args, **kwargs)
+
+
+class TestCaseWithErrorReportingThread(test.TestCase):
+
+  @classmethod
+  def setUpClass(cls):
+    cls._threading_thread = threading.Thread
+    threading.Thread = ErrorReportingThread
+    super(TestCaseWithErrorReportingThread, cls).setUpClass()
+
+  @classmethod
+  def tearDownClass(cls):
+    super(TestCaseWithErrorReportingThread, cls).tearDownClass()
+    threading.Thread = cls._threading_thread
+
+  def setUp(self):
+    ErrorReportingThread.error = None
+    super(TestCaseWithErrorReportingThread, self).setUp()
+
+  def tearDown(self):
+    super(TestCaseWithErrorReportingThread, self).tearDown()
+    if ErrorReportingThread.error:
+      raise ErrorReportingThread.error  # pylint: disable=raising-bad-type
+
+
 def make_client(num_workers, num_ps):
   # TODO(rchao): Test the internal rpc_layer version.
   cluster_def = multi_worker_test_base.create_in_process_cluster(
@@ -52,7 +97,7 @@ def make_client(num_workers, num_ps):
   return parameter_server_client.ParameterServerClient(cluster_resolver)


-class ParameterServerClientTest(test.TestCase):
+class ParameterServerClientTest(TestCaseWithErrorReportingThread):

   @classmethod
   def setUpClass(cls):
@@ -304,7 +349,7 @@ class VariablePartitioningScopeTest(test.TestCase):
     self.assertEqual(var_sum, 10.0)


-class ErrorReportingTest(test.TestCase):
+class ErrorReportingTest(TestCaseWithErrorReportingThread):

   @classmethod
   def setUpClass(cls):
@@ -344,8 +389,16 @@ class ErrorReportingTest(test.TestCase):
       while True:
         self.client.schedule(self._normal_function)

+  def testScheduleRaiseErrorWithMultipleFailure(self):
+    for _ in range(3):
+      self.client.schedule(self._normal_function)
+    self.client.schedule(self._error_function)
+    with self.assertRaises(errors.InvalidArgumentError):
+      while True:
+        self.client.schedule(self._error_function)
+    self.client.join()
+
   def testErrorWillbeCleared(self):
     self.skipTest("b/157597579")
     self.client.schedule(self._error_function)
     with self.assertRaises(errors.InvalidArgumentError):
       self.client.join()
@@ -356,7 +409,7 @@ class ErrorReportingTest(test.TestCase):
     with self.assertRaises(errors.InvalidArgumentError):
       self.client.join()

-  def testFutureReturnError(self):
+  def testRemoteValueReturnError(self):
     result = self.client.schedule(self._error_function)

     with self.assertRaises(errors.InvalidArgumentError):