Add auto_restart to multi_process_runner

This helps create fault-tolerance test cases. MWMS (MultiWorkerMirroredStrategy) currently requires an external system to bring back tasks that are down; otherwise the remaining workers may hang forever.

Ideally the remaining workers should error out, which is what I'm working on. But it's beneficial to have test cases reflecting the current behavior, since in many deployments we do have a cluster management system that does the restart (e.g. k8s).

This also changes the behavior of dependence_on_chief. We used to terminate the cluster, if the chief had exited, only when join() was called. Now, with a watchdog thread, that happens immediately after the chief terminates.

PiperOrigin-RevId: 324094287
Change-Id: If183d839168d71a2ffe3de10d1410e2116ea1f80
Author: Ran Chen, 2020-07-30 15:19:06 -07:00 (committed by TensorFlower Gardener)
Commit: 56bb1ccfaa (parent: c111987612)
2 changed files with 285 additions and 87 deletions
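
For context, the new flag is exercised as in the test_auto_restart case added below. Here is a minimal sketch of the same pattern (assuming the TensorFlow-internal modules multi_process_runner and multi_worker_test_base, and a test binary that runs under multi_process_runner.test_main()):

from tensorflow.python.distribute import multi_process_runner
from tensorflow.python.distribute import multi_worker_test_base

def proc_func(counter):
  counter.value += 1
  if counter.value == 1:
    # The first run fails; with auto_restart=True the watchdog thread
    # restarts the task instead of letting the cluster hang.
    raise ValueError

manager = multi_process_runner.manager()
counter = manager.Value(int, 0)
mpr = multi_process_runner.MultiProcessRunner(
    proc_func,
    multi_worker_test_base.create_cluster_spec(num_workers=1),
    args=(counter,),
    auto_restart=True)
mpr.start()
mpr.join()  # Succeeds: the restarted run exits with code 0.
assert counter.value == 2  # proc_func ran twice in total.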


@@ -67,7 +67,8 @@ except ImportError:
# exception stack trace info is stored in exc_info to pass on to parent process
# to be re-raised.
_ProcessStatusInfo = collections.namedtuple(
'_ProcessStatusInfo', ['is_successful', 'exc_info', 'return_value'])
'_ProcessStatusInfo',
['task_type', 'task_id', 'is_successful', 'exc_info', 'return_value'])
# Information returned from a successful MultiProcessRunner run.
MultiProcessRunnerResult = collections.namedtuple('MultiProcessRunnerResult',
@@ -124,6 +125,8 @@ class MultiProcessRunner(object):
list_stdout=False,
use_dill_for_args=True,
daemon=False,
dependence_on_chief=True,
auto_restart=False,
args=None,
kwargs=None):
"""Creates a multi-process runner.
@@ -161,6 +164,11 @@ class MultiProcessRunner(object):
can pickle more objects, but doesn't work with types in
`multiprocessing` library like `Mutex`.
daemon: Whether to start processes as daemons.
dependence_on_chief: Whether to terminate the cluster if the chief exits.
If auto_restart is True, it only terminates the cluster if the chief
exits with a zero exit code.
auto_restart: Whether to automatically restart processes that exit with
non-zero exit code.
args: Positional arguments to be sent to functions run on processes.
kwargs: Keyword arguments to be sent to functions run on processes.
@@ -190,9 +198,10 @@ class MultiProcessRunner(object):
self._stream_stdout = stream_stdout
# TODO(rchao): Revisit list_stdout argument to consider other solution.
self._list_stdout = list_stdout
self._dependence_on_chief = True
self._dependence_on_chief = dependence_on_chief
self._use_dill_for_args = use_dill_for_args
self._daemon = daemon
self._auto_restart = auto_restart
self._args = args or ()
self._kwargs = kwargs or {}
@@ -201,8 +210,15 @@ class MultiProcessRunner(object):
self._executing_eagerly = context.executing_eagerly()
self._joined = False
self._process_lock = threading.Lock()
# Guarded by self._process_lock.
self._processes = {}
self._outstanding_subprocess_count = 0
# Record which processes are terminated. Due to a bug in Python<3.7,
# terminated processes return 255 exit code, which should cause an exception
# in join().
# https://bugs.python.org/issue30589
# Guarded by self._process_lock.
self._terminated = set()
self._reading_threads = []
self._manager = manager()
@@ -215,8 +231,7 @@ class MultiProcessRunner(object):
# safe.
self._streaming_queue = self._manager.Queue()
# This flag will be set to True once terminate_all() is called.
self._all_forced_terminated = False
self._watchdog_thread = None
def set_args(self, args=None, kwargs=None):
self._args = args or self._args
@@ -281,7 +296,7 @@ class MultiProcessRunner(object):
daemon=self._daemon)
p.start()
self._processes[(task_type, task_id)] = p
self._outstanding_subprocess_count += 1
self._terminated.discard((task_type, task_id))
# For each subprocess, we dedicate a thread continuously reading lines
# from them.
@@ -291,17 +306,26 @@ class MultiProcessRunner(object):
thread.start()
self._reading_threads.append(thread)
if self._watchdog_thread is None or not self._watchdog_thread.is_alive():
self._watchdog_thread = threading.Thread(target=self._process_watchdog)
self._watchdog_thread.start()
def start(self):
"""Starts processes, one for each task in `cluster_spec`.
Note that this is best effort by the applicable multiprocessing library,
and it may take up to several seconds for a subprocess to be successfully
started.
"""
if self._processes:
raise ValueError('MultiProcessRunner already started.')
for task_type, addresses in self._cluster_spec.items():
for task_id, _ in enumerate(addresses):
self._start_subprocess_and_reading_thread(task_type, task_id)
with self._process_lock:
if self._processes:
raise ValueError('MultiProcessRunner already started.')
if self._joined:
raise ValueError('cannot start new processes after '
'MultiProcessRunner.join() is called')
for task_type, addresses in self._cluster_spec.items():
for task_id, _ in enumerate(addresses):
self._start_subprocess_and_reading_thread(task_type, task_id)
# TODO(rchao): Remove the need of using SIGALRM if possible. At this time,
# without this the tests become very flaky.
@@ -353,10 +377,14 @@ class MultiProcessRunner(object):
"""
if self._processes:
raise ValueError('MultiProcessRunner already started.')
for task_type, addresses in self._cluster_spec.items():
for task_id, _ in enumerate(addresses):
if not (task_type == as_task_type and task_id == as_task_id):
self._start_subprocess_and_reading_thread(task_type, task_id)
with self._process_lock:
if self._joined:
raise ValueError('cannot start new processes after '
'MultiProcessRunner.join() is called')
for task_type, addresses in self._cluster_spec.items():
for task_id, _ in enumerate(addresses):
if not (task_type == as_task_type and task_id == as_task_id):
self._start_subprocess_and_reading_thread(task_type, task_id)
_set_tf_config(as_task_type, as_task_id, self._cluster_spec,
self._rpc_layer)
@@ -392,13 +420,17 @@ class MultiProcessRunner(object):
args: Optional positional arguments to be supplied in `proc_func`.
kwargs: Optional keyword arguments to be supplied in `proc_func`.
"""
self._start_subprocess_and_reading_thread(
task_type,
task_id,
cluster_spec=cluster_spec,
proc_func=proc_func,
args=args or (),
kwargs=kwargs or {})
with self._process_lock:
if self._joined:
raise ValueError('cannot start new processes after '
'MultiProcessRunner.join() is called')
self._start_subprocess_and_reading_thread(
task_type,
task_id,
cluster_spec=cluster_spec,
proc_func=proc_func,
args=args or (),
kwargs=kwargs or {})
def _queue_to_list(self, queue_to_convert):
"""Convert `queue.Queue` to `list`."""
@@ -411,9 +443,17 @@ class MultiProcessRunner(object):
break
return list_to_return
def _get_process_statuses(self):
# One worker may have multiple statuses. We only keep the last one.
statuses = {}
for status in self._queue_to_list(self._process_status_queue):
statuses[(status.task_type, status.task_id)] = status
return statuses
def get_process_id(self, task_type, task_id):
"""Returns the subprocess id given the task type and task id."""
p = self._processes.get((task_type, task_id), None)
with self._process_lock:
p = self._processes.get((task_type, task_id), None)
return p.pid if p else None
def get_process_exit_code(self, task_type, task_id):
@@ -430,22 +470,54 @@ class MultiProcessRunner(object):
KeyError: If the corresponding subprocess is not found with `task_type`
and `task_id`.
"""
p = self._processes[(task_type, task_id)]
with self._process_lock:
p = self._processes[(task_type, task_id)]
return p.exitcode if p else None
def _join_or_terminate(self, task_type, task_id, process, timeout):
"""Joins a process. If it times out, terminate all procsses."""
logging.info('joining %s-%d', task_type, task_id)
process.join(timeout)
# If exitcode is None, the process hasn't terminated and this is a
# timeout.
if process.exitcode is None:
# Force termination to dump worker processes stack trace.
self.terminate_all(sig=signal.SIGTERM)
process_statuses = self._queue_to_list(self._process_status_queue)
raise SubprocessTimeoutError(
'%s-%d and possibly more subprocesses timed out.' %
(task_type, task_id), self._get_mpr_result(process_statuses))
def _process_watchdog(self):
"""Simulates a cluster management system.
- If auto_restart is True, it restarts processes that exit with a non-zero
exit code. Note that when join() times out it overrides auto_restart to
False.
- If dependence_on_chief is True, it terminates all processes once the chief
exits. If auto_restart is also True, it only terminates all processes if
the chief exits with a zero exit code; otherwise it restarts the chief.
This runs in self._watchdog_thread.
"""
while True:
time.sleep(1)
with self._process_lock:
chief = self._processes.get(('chief', 0), None)
# Terminate the cluster when _dependence_on_chief is True if either:
# - chief has exited with zero exit code.
# - chief has exited with non-zero exit code and self._auto_restart is
# False.
if chief and self._dependence_on_chief and chief.exitcode is not None:
if chief.exitcode == 0 or (not self._auto_restart):
for p in self._processes.values():
# Give other processes a chance to exit on their own.
p.join(timeout=3)
self._terminate_all()
for p in self._processes.values():
p.join()
return
# Auto restart failed processes if self._auto_restart is True.
if self._auto_restart:
has_failure = False
for (task_type, task_id), p in self._processes.items():
if p.exitcode is not None and p.exitcode != 0:
has_failure = True
logging.info('Restarting failed %s-%d', task_type, task_id)
self._start_subprocess_and_reading_thread(task_type, task_id)
if has_failure:
continue
# Exit the thread if all processes have exited at this point.
if all(p.exitcode is not None for p in self._processes.values()):
return
def join(self, timeout=_DEFAULT_TIMEOUT_SEC):
"""Joins all the processes with timeout.
@@ -489,41 +561,40 @@ class MultiProcessRunner(object):
cases.
Exception: if there is an Exception propagated from any subprocess.
"""
if self._joined:
raise ValueError("MultiProcessRunner can't be joined twice.")
self._joined = True
with self._process_lock:
if self._joined:
raise ValueError("MultiProcessRunner can't be joined twice.")
self._joined = True
chief = self._processes.get(('chief', 0), None)
if self._dependence_on_chief and chief:
self._join_or_terminate('chief', 0, chief, timeout)
# Give other processes a chance to exit on their own.
for p in self._processes.values():
p.join(timeout=3)
self.terminate_all()
else:
for (task_type, task_id), p in self._processes.items():
self._join_or_terminate(task_type, task_id, p, timeout)
self._watchdog_thread.join(timeout)
if self._watchdog_thread.is_alive():
# Timeout. Force termination to dump worker processes stack trace.
with self._process_lock:
self._auto_restart = False
self.terminate_all(sig=signal.SIGTERM)
self._watchdog_thread.join()
process_statuses = self._get_process_statuses()
raise SubprocessTimeoutError('one or more subprocesses timed out.',
self._get_mpr_result(process_statuses))
for (task_type, task_id), p in self._processes.items():
logging.info('%s-%d exit code: %s', task_type, task_id, p.exitcode)
process_statuses = self._queue_to_list(self._process_status_queue)
for process_status in process_statuses:
process_statuses = self._get_process_statuses()
for process_status in process_statuses.values():
assert isinstance(process_status, _ProcessStatusInfo)
if not process_status.is_successful:
six.reraise(*process_status.exc_info)
# Checking all the processes that are expected to exit properly.
for (task_type, task_id), p in self._processes.items():
if self._dependence_on_chief and chief and task_type != 'chief':
# If _dependence_on_chief, other processes may have been
# forced-terminated, which is expected.
continue
# Successfully exiting process has exit code 0.
if p.exitcode is None or p.exitcode > 0:
# A successfully exiting process has exit code 0. We ignore processes that
# were terminated.
assert p.exitcode is not None
if (p.exitcode > 0 and (task_type, task_id) not in self._terminated):
raise UnexpectedSubprocessExitError(
'Subprocess %s-%d exited with exit code %d. See logs for details.' %
(task_type, task_id, p.exitcode),
'Subprocess %s-%d exited with exit code %s. See logs for details.'
% (task_type, task_id, p.exitcode),
self._get_mpr_result(process_statuses))
logging.info('Joining log reading threads.')
@@ -539,34 +610,60 @@ class MultiProcessRunner(object):
def _get_mpr_result(self, process_statuses):
stdout = self._queue_to_list(self._streaming_queue)
return_values = []
for process_status in process_statuses:
for process_status in process_statuses.values():
if process_status.return_value is not None:
return_values.append(process_status.return_value)
return MultiProcessRunnerResult(stdout=stdout, return_value=return_values)
def terminate(self, task_type, task_id):
"""Terminates the process with `task_type` and `task_id`."""
p = self._processes.get((task_type, task_id), None)
if p is None:
raise ValueError('{}-{} does not exist'.format(task_type, task_id))
# TODO(crccw): change to use Process.terminate() as well.
self._parent_to_sub_queue.put('terminate {} {}'.format(task_type, task_id))
p.join()
"""Terminates the process with `task_type` and `task_id`.
If auto_restart=True, the terminated task will be restarted unless the chief
has already exited with a zero exit code.
Args:
task_type: the task type.
task_id: the task id.
"""
with self._process_lock:
p = self._processes.get((task_type, task_id), None)
if p is None:
raise ValueError('{}-{} does not exist'.format(task_type, task_id))
self._terminated.add((task_type, task_id))
# TODO(crccw): change to use Process.terminate() as well.
self._parent_to_sub_queue.put('terminate {} {}'.format(
task_type, task_id))
p.join()
def _terminate_all(self, sig=None):
"""Terminates all subprocesses.
The caller is required to hold self._process_lock.
Args:
sig: the signal used to terminate the process. The default is SIGKILL.
"""
def terminate_all(self, sig=None):
"""Terminates all subprocesses."""
# Use SIGKILL as the default. On systems where it's unavailable, such as
# Windows, use SIGTERM.
sig = sig or getattr(signal, 'SIGKILL', signal.SIGTERM)
for (task_type, task_id), p in self._processes.items():
if p.exitcode is not None:
continue
try:
os.kill(p.pid, sig)
self._terminated.add((task_type, task_id))
logging.info('%s-%d terminated with signal %r.', task_type, task_id,
sig)
except ProcessLookupError:
logging.info('Attempting to kill %s-%d but it does not exist.',
task_type, task_id)
self._all_forced_terminated = True
def terminate_all(self, sig=None):
"""Terminates all subprocesses."""
with self._process_lock:
self._terminate_all(sig)
class _Process(multi_process_lib.Process):
@@ -625,11 +722,13 @@ class _ProcFunc(object):
time.sleep(0.1)
self._resources.process_status_queue.put(
_ProcessStatusInfo(
task_type=task_type,
task_id=task_id,
is_successful=True,
exc_info=None,
return_value=None))
# `os._exit(0)` is used to more reliably terminate a subprocess.
os._exit(0) # pylint: disable=protected-access
# `os._exit(1)` is used to more reliably terminate a subprocess.
os._exit(1) # pylint: disable=protected-access
def _close_streaming(self):
"""Close stdout, stderr and streaming pipe.
@@ -685,7 +784,8 @@ class _ProcFunc(object):
v2_compat.enable_v2_behavior()
with self._runtime_mode(test_env.executing_eagerly):
info = _run_contained(proc_func, args, kwargs)
info = _run_contained(test_env.task_type, test_env.task_id, proc_func,
args, kwargs)
self._resources.process_status_queue.put(info)
# Re-raise the exception in addition to reporting it to the parent
@@ -774,7 +874,7 @@ class MultiProcessPoolRunner(object):
task_type,
task_id,
proc_func=_pool_runner_worker,
args=(initializer, conn2))
args=(task_type, task_id, initializer, conn2))
def run(self, proc_func, args=None, kwargs=None):
"""Runs `proc_func` with `args` and `kwargs` on all jobs.
@@ -819,7 +919,7 @@ class MultiProcessPoolRunner(object):
return return_values
def _pool_runner_worker(initializer, conn):
def _pool_runner_worker(task_type, task_id, initializer, conn):
"""Function that runs on the workers in a pool.
It listens for callables to run and returns the result until `conn` is closed.
@@ -827,8 +927,10 @@ def _pool_runner_worker(initializer, conn):
`conn`.
Args:
initializer: A callable to execute during startup.
conn: A multiprocessing.Connection object to listen for tasks and send
task_type: the task type.
task_id: the task index.
initializer: a callable to execute during startup.
conn: a multiprocessing.Connection object to listen for tasks and send
results.
"""
if initializer:
@@ -840,22 +942,24 @@ def _pool_runner_worker(initializer, conn):
except EOFError:
break
proc_func = dill.loads(proc_func)
info = _run_contained(proc_func, args, kwargs)
info = _run_contained(task_type, task_id, proc_func, args, kwargs)
sys.stdout.flush()
sys.stderr.flush()
conn.send(info)
def _run_contained(proc_func, args, kwargs):
def _run_contained(task_type, task_id, proc_func, args, kwargs):
"""Runs `proc_func` with `args` and `kwargs`.
The function returns _ProcessStatusInfo which captures the return value and
the exception.
Args:
proc_func: The function to be run.
args: Optional positional arguments to be supplied in `proc_func`.
kwargs: Optional keyword arguments to be supplied in `proc_func`.
task_type: the task type.
task_id: the task index.
proc_func: the function to be run.
args: optional positional arguments to be supplied in `proc_func`.
kwargs: optional keyword arguments to be supplied in `proc_func`.
Returns:
a _ProcessStatusInfo.
@@ -868,6 +972,8 @@ def _run_contained(proc_func, args, kwargs):
return_value = proc_func(*args, **kwargs)
is_successful = True
return _ProcessStatusInfo(
task_type=task_type,
task_id=task_id,
is_successful=is_successful,
exc_info=exc_info,
return_value=return_value)
@@ -877,6 +983,8 @@ def _run_contained(proc_func, args, kwargs):
except Exception: # pylint: disable=broad-except
exc_info = sys.exc_info()
return _ProcessStatusInfo(
task_type=task_type,
task_id=task_id,
is_successful=is_successful,
exc_info=exc_info,
return_value=return_value)


@@ -156,11 +156,8 @@ class MultiProcessRunnerTest(test.TestCase):
mpr.start()
time.sleep(5)
mpr.terminate('worker', 0)
with self.assertRaises(
multi_process_runner.UnexpectedSubprocessExitError) as cm:
mpr.join()
std_stream_results = cm.exception.mpr_result.stdout
std_stream_results = mpr.join().stdout
# Worker 0 is terminated in the middle, so it should not have iteration 9
# printed.
@@ -388,6 +385,99 @@ class MultiProcessRunnerTest(test.TestCase):
'Subprocess worker-0 exited with exit code 10'):
mpr.join()
def test_auto_restart(self):
def proc_func(counter):
counter.value += 1
if counter.value == 1:
raise ValueError
manager = multi_process_runner.manager()
counter = manager.Value(int, 0)
mpr = multi_process_runner.MultiProcessRunner(
proc_func,
multi_worker_test_base.create_cluster_spec(num_workers=1),
args=(counter,),
auto_restart=True)
mpr.start()
mpr.join()
self.assertEqual(counter.value, 2)
def test_auto_restart_and_timeout(self):
def proc_func():
time.sleep(1)
raise ValueError
mpr = multi_process_runner.MultiProcessRunner(
proc_func,
multi_worker_test_base.create_cluster_spec(num_workers=1),
auto_restart=True)
mpr.start()
with self.assertRaises(multi_process_runner.SubprocessTimeoutError):
mpr.join(timeout=10)
def test_auto_restart_and_chief(self):
# If the chief has exited with zero exit code, auto restart should stop
# restarting other tasks even if they fail.
def proc_func():
time.sleep(1)
if multi_worker_test_base.get_task_type() != 'chief':
raise ValueError
manager = multi_process_runner.manager()
mpr = multi_process_runner.MultiProcessRunner(
proc_func,
multi_worker_test_base.create_cluster_spec(
has_chief=True, num_workers=1),
auto_restart=True)
mpr.start()
with self.assertRaises(ValueError):
mpr.join(timeout=10)
def test_auto_restart_failure_immediate_after_restart(self):
# Test the case when worker-0 fails immediately after worker-1 restarts.
def proc_func():
time.sleep(5)
mpr = multi_process_runner.MultiProcessRunner(
proc_func,
multi_worker_test_base.create_cluster_spec(
has_chief=False, num_workers=2),
auto_restart=True)
mpr.start()
pid = mpr.get_process_id('worker', 1)
mpr.terminate('worker', 1)
while mpr.get_process_id('worker', 1) == pid:
time.sleep(0.1)
mpr.terminate('worker', 0)
mpr.join(timeout=20)
def test_auto_restart_terminate(self):
# Tasks terminated by the user should also be restarted.
def proc_func(counter):
counter.value += 1
if counter.value == 1:
time.sleep(100)
manager = multi_process_runner.manager()
counter = manager.Value(int, 0)
mpr = multi_process_runner.MultiProcessRunner(
proc_func,
multi_worker_test_base.create_cluster_spec(
has_chief=False, num_workers=1),
args=(counter,),
auto_restart=True)
mpr.start()
time.sleep(3)
mpr.terminate('worker', 0)
mpr.join(timeout=20)
self.assertEqual(counter.value, 2)
class MultiProcessPoolRunnerTest(test.TestCase):