Add auto_restart to multi_process_runner
This helps create fault tolerance test cases. MWMS currently requires an external system that brings back tasks that are down; otherwise the remaining workers may hang forever. Ideally the remaining workers should error out, which is what I'm working on. But it's beneficial to have test cases reflecting the current behavior, since in many deployments we do have a cluster management system that does the restart (e.g. k8s).

This also changes the behavior of dependence_on_chief. We used to terminate the cluster only when join() was called, if the chief had exited. Now, with a watchdog thread, that happens immediately after the chief terminates.

PiperOrigin-RevId: 324682642
Change-Id: I56ce27658298916d1ddd4507b90b79db0a2d4673
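For reference, a minimal sketch of how a test exercises the new flag, mirroring the test_auto_restart case added below (proc_func and the shared counter are illustrative):

    from tensorflow.python.distribute import multi_process_runner
    from tensorflow.python.distribute import multi_worker_test_base

    def proc_func(counter):
      counter.value += 1
      if counter.value == 1:
        raise ValueError  # First attempt fails; the watchdog restarts the task.

    manager = multi_process_runner.manager()
    counter = manager.Value(int, 0)
    mpr = multi_process_runner.MultiProcessRunner(
        proc_func,
        multi_worker_test_base.create_cluster_spec(num_workers=1),
        args=(counter,),
        auto_restart=True)  # Restart tasks that exit with a non-zero exit code.
    mpr.start()
    mpr.join()
    assert counter.value == 2  # One failed run plus one successful restart.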
Parent: 6b3990c84e
Commit: b58a8717b1

--- tensorflow/python/distribute/multi_process_runner.py
+++ tensorflow/python/distribute/multi_process_runner.py
@@ -67,7 +67,8 @@ except ImportError:
 # exception stack trace info is stored in exc_info to pass on to parent process
 # to be re-raised.
 _ProcessStatusInfo = collections.namedtuple(
-    '_ProcessStatusInfo', ['is_successful', 'exc_info', 'return_value'])
+    '_ProcessStatusInfo',
+    ['task_type', 'task_id', 'is_successful', 'exc_info', 'return_value'])
 
 # Information returned from a successful MultiProcessRunner run.
 MultiProcessRunnerResult = collections.namedtuple('MultiProcessRunnerResult',
@@ -97,6 +98,11 @@ Resources = collections.namedtuple('Resources', [
 # "medium" timeout of the test runs.
 _DEFAULT_TIMEOUT_SEC = 200
 
+# The timeout in seconds to wait to force kill a child process. When a child
+# process times out we first try to SIGTERM it so that it has a chance to dump
+# stacktraces. However dumping stacktrace can take a long time.
+_FORCE_KILL_WAIT_SEC = 30
+
 
 class MultiProcessRunner(object):
   """A utility class to start multiple processes to simulate a cluster.
@@ -124,6 +130,8 @@ class MultiProcessRunner(object):
                list_stdout=False,
                use_dill_for_args=True,
                daemon=False,
+               dependence_on_chief=True,
+               auto_restart=False,
                args=None,
                kwargs=None):
     """Creates a multi-process runner.
@@ -161,6 +169,11 @@ class MultiProcessRunner(object):
         can pickle more objects, but doesn't work with types in
         `multiprocessing` library like `Mutex`.
       daemon: Whether to start processes as daemons.
+      dependence_on_chief: Whether to terminate the cluster if the chief exits.
+        If auto_restart is True, it only terminates the cluster if the chief
+        exits with a zero exit code.
+      auto_restart: Whether to automatically restart processes that exit with
+        a non-zero exit code.
       args: Positional arguments to be sent to functions run on processes.
       kwargs: Keyword arguments to be sent to functions run on processes.
@@ -190,9 +203,10 @@ class MultiProcessRunner(object):
     self._stream_stdout = stream_stdout
     # TODO(rchao): Revisit list_stdout argument to consider other solution.
     self._list_stdout = list_stdout
-    self._dependence_on_chief = True
+    self._dependence_on_chief = dependence_on_chief
     self._use_dill_for_args = use_dill_for_args
     self._daemon = daemon
+    self._auto_restart = auto_restart
     self._args = args or ()
     self._kwargs = kwargs or {}
@@ -201,8 +215,15 @@
     self._executing_eagerly = context.executing_eagerly()
 
     self._joined = False
+    self._process_lock = threading.Lock()
+    # Guarded by self._process_lock.
     self._processes = {}
-    self._outstanding_subprocess_count = 0
+    # Record which processes are terminated. Due to a bug in Python < 3.7,
+    # terminated processes return a 255 exit code, which would otherwise cause
+    # an exception in join().
+    # https://bugs.python.org/issue30589
+    # Guarded by self._process_lock.
+    self._terminated = set()
     self._reading_threads = []
 
     self._manager = manager()
@@ -215,8 +236,7 @@
     # safe.
     self._streaming_queue = self._manager.Queue()
 
-    # This flag will be set to True once terminate_all() is called.
-    self._all_forced_terminated = False
+    self._watchdog_thread = None
 
   def set_args(self, args=None, kwargs=None):
     self._args = args or self._args
@@ -281,7 +301,7 @@
         daemon=self._daemon)
     p.start()
     self._processes[(task_type, task_id)] = p
-    self._outstanding_subprocess_count += 1
+    self._terminated.discard((task_type, task_id))
 
     # For each subprocess, we dedicate a thread continuously reading lines
     # from them.
@@ -291,14 +311,23 @@
       thread.start()
       self._reading_threads.append(thread)
 
+    if self._watchdog_thread is None or not self._watchdog_thread.is_alive():
+      self._watchdog_thread = threading.Thread(target=self._process_watchdog)
+      self._watchdog_thread.start()
+
   def start(self):
     """Starts processes, one for each task in `cluster_spec`.
 
     Note that this is best effort by the applicable multiprocessing library,
     and it may take up to seconds for a subprocess to be successfully started.
     """
-    if self._processes:
-      raise ValueError('MultiProcessRunner already started.')
-    for task_type, addresses in self._cluster_spec.items():
-      for task_id, _ in enumerate(addresses):
-        self._start_subprocess_and_reading_thread(task_type, task_id)
+    with self._process_lock:
+      if self._processes:
+        raise ValueError('MultiProcessRunner already started.')
+      if self._joined:
+        raise ValueError('cannot start new processes after '
+                         'MultiProcessRunner.join() is called')
+
+      for task_type, addresses in self._cluster_spec.items():
+        for task_id, _ in enumerate(addresses):
+          self._start_subprocess_and_reading_thread(task_type, task_id)
@@ -353,6 +382,10 @@
     """
     if self._processes:
       raise ValueError('MultiProcessRunner already started.')
+    with self._process_lock:
+      if self._joined:
+        raise ValueError('cannot start new processes after '
+                         'MultiProcessRunner.join() is called')
     for task_type, addresses in self._cluster_spec.items():
       for task_id, _ in enumerate(addresses):
         if not (task_type == as_task_type and task_id == as_task_id):
@@ -392,6 +425,10 @@
       args: Optional positional arguments to be supplied in `proc_func`.
       kwargs: Optional keyword arguments to be supplied in `proc_func`.
     """
+    with self._process_lock:
+      if self._joined:
+        raise ValueError('cannot start new processes after '
+                         'MultiProcessRunner.join() is called')
     self._start_subprocess_and_reading_thread(
         task_type,
         task_id,
@@ -411,8 +448,16 @@
         break
     return list_to_return
 
+  def _get_process_statuses(self):
+    # One worker may have multiple statuses. We only keep the last one.
+    statuses = {}
+    for status in self._queue_to_list(self._process_status_queue):
+      statuses[(status.task_type, status.task_id)] = status
+    return statuses
+
   def get_process_id(self, task_type, task_id):
     """Returns the subprocess id given the task type and task id."""
-    p = self._processes.get((task_type, task_id), None)
+    with self._process_lock:
+      p = self._processes.get((task_type, task_id), None)
     return p.pid if p else None
 
@@ -430,22 +475,54 @@
       KeyError: If the corresponding subprocess is not found with `task_type`
         and `task_id`.
     """
-    p = self._processes[(task_type, task_id)]
+    with self._process_lock:
+      p = self._processes[(task_type, task_id)]
     return p.exitcode if p else None
 
-  def _join_or_terminate(self, task_type, task_id, process, timeout):
-    """Joins a process. If it times out, terminate all procsses."""
-    logging.info('joining %s-%d', task_type, task_id)
-    process.join(timeout)
-    # If exitcode is None, the process aren't terminated and this is a
-    # timeout.
-    if process.exitcode is None:
-      # Force termination to dump worker processes stack trace.
-      self.terminate_all(sig=signal.SIGTERM)
-      process_statuses = self._queue_to_list(self._process_status_queue)
-      raise SubprocessTimeoutError(
-          '%s-%d and possibly more subprocesses timed out.' %
-          (task_type, task_id), self._get_mpr_result(process_statuses))
+  def _process_watchdog(self):
+    """Simulates a cluster management system.
+
+    - If auto_restart is True, it restarts processes that exit with a non-zero
+      exit code. Note that when join() times out it overrides auto_restart to
+      False.
+    - If dependence_on_chief is True, it terminates all processes once the chief
+      exits. If auto_restart is also True, it only terminates all processes if
+      the chief exits with a zero exit code, otherwise it restarts the chief.
+
+    This runs in self._watchdog_thread.
+    """
+    while True:
+      time.sleep(1)
+      with self._process_lock:
+        chief = self._processes.get(('chief', 0), None)
+        # Terminate the cluster when _dependence_on_chief is True if either:
+        # - chief has exited with zero exit code.
+        # - chief has exited with non-zero exit code and self._auto_restart is
+        #   False.
+        if chief and self._dependence_on_chief and chief.exitcode is not None:
+          if chief.exitcode == 0 or (not self._auto_restart):
+            for p in self._processes.values():
+              # Give other processes a chance to exit on their own.
+              p.join(timeout=3)
+            self._terminate_all()
+            for p in self._processes.values():
+              p.join()
+            return
+
+        # Auto restart failed processes if self._auto_restart is True.
+        if self._auto_restart:
+          has_failure = False
+          for (task_type, task_id), p in self._processes.items():
+            if p.exitcode is not None and p.exitcode != 0:
+              has_failure = True
+              logging.info('Restarting failed %s-%d', task_type, task_id)
+              self._start_subprocess_and_reading_thread(task_type, task_id)
+          if has_failure:
+            continue
+
+        # Exit the thread if all processes have exited at this point.
+        if all(p.exitcode is not None for p in self._processes.values()):
+          return
 
   def join(self, timeout=_DEFAULT_TIMEOUT_SEC):
     """Joins all the processes with timeout.
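The watchdog above ticks once a second and simulates a cluster manager. A hypothetical distillation of its chief-handling branch (function name and signature are illustrative, not part of the module):

    def should_terminate_cluster(chief_exitcode, dependence_on_chief,
                                 auto_restart):
      # Chief still running, or nothing depends on it: keep the cluster alive.
      if not dependence_on_chief or chief_exitcode is None:
        return False
      # Chief exited: tear down on success, or on failure when auto_restart is
      # off; a failed chief is otherwise restarted like any other task.
      return chief_exitcode == 0 or not auto_restart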
@@ -489,41 +566,48 @@
         cases.
       Exception: if there is an Exception propagated from any subprocess.
     """
-    if self._joined:
-      raise ValueError("MultiProcessRunner can't be joined twice.")
-    self._joined = True
+    with self._process_lock:
+      if self._joined:
+        raise ValueError("MultiProcessRunner can't be joined twice.")
+      self._joined = True
 
-    chief = self._processes.get(('chief', 0), None)
-    if self._dependence_on_chief and chief:
-      self._join_or_terminate('chief', 0, chief, timeout)
-      # Give other processes a chance to exit on their own.
-      for p in self._processes.values():
-        p.join(timeout=3)
-      self.terminate_all()
-    else:
-      for (task_type, task_id), p in self._processes.items():
-        self._join_or_terminate(task_type, task_id, p, timeout)
+    self._watchdog_thread.join(timeout)
+    if self._watchdog_thread.is_alive():
+      # Timeout. Force termination to dump worker processes stack trace.
+      with self._process_lock:
+        self._auto_restart = False
+      logging.error('Timeout when joining for child processes. Terminating...')
+      self.terminate_all(sig=signal.SIGTERM)
+      # Wait for the processes to terminate by themselves first, so they have a
+      # chance to dump stacktraces. After _FORCE_KILL_WAIT_SEC, we SIGKILL them.
+      self._watchdog_thread.join(_FORCE_KILL_WAIT_SEC)
+      if self._watchdog_thread.is_alive():
+        logging.error('Timeout when waiting for child processes to '
+                      'print stacktrace. Sending SIGKILL...')
+        self.terminate_all()
+        self._watchdog_thread.join()
+      process_statuses = self._get_process_statuses()
+      raise SubprocessTimeoutError('one or more subprocesses timed out.',
+                                   self._get_mpr_result(process_statuses))
 
     for (task_type, task_id), p in self._processes.items():
       logging.info('%s-%d exit code: %s', task_type, task_id, p.exitcode)
 
-    process_statuses = self._queue_to_list(self._process_status_queue)
-    for process_status in process_statuses:
+    process_statuses = self._get_process_statuses()
+    for process_status in process_statuses.values():
       assert isinstance(process_status, _ProcessStatusInfo)
       if not process_status.is_successful:
         six.reraise(*process_status.exc_info)
 
     # Checking all the processes that are expected to exit properly.
     for (task_type, task_id), p in self._processes.items():
-      if self._dependence_on_chief and chief and task_type != 'chief':
-        # If _dependence_on_chief, other processes may have been
-        # forced-terminated, which is expected.
-        continue
-      # Successfully exiting process has exit code 0.
-      if p.exitcode is None or p.exitcode > 0:
+      # Successfully exiting process has exit code 0. We ignore processes that
+      # are terminated.
+      assert p.exitcode is not None
+      if (p.exitcode > 0 and (task_type, task_id) not in self._terminated):
         raise UnexpectedSubprocessExitError(
-            'Subprocess %s-%d exited with exit code %d. See logs for details.' %
-            (task_type, task_id, p.exitcode),
+            'Subprocess %s-%d exited with exit code %s. See logs for details.'
+            % (task_type, task_id, p.exitcode),
             self._get_mpr_result(process_statuses))
 
     logging.info('Joining log reading threads.')
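On timeout, join() now escalates in two steps so children get a chance to dump stack traces: SIGTERM first, then SIGKILL after _FORCE_KILL_WAIT_SEC. A generic POSIX-only sketch of that pattern (terminate_with_grace is illustrative, not part of this module):

    import os
    import signal
    import time

    def terminate_with_grace(pid, grace_seconds=30):
      os.kill(pid, signal.SIGTERM)  # Ask first, so the child can dump traces.
      deadline = time.time() + grace_seconds
      while time.time() < deadline:
        try:
          os.kill(pid, 0)  # Signal 0 only probes; raises once the pid is gone.
        except ProcessLookupError:
          return True  # Exited within the grace period.
        time.sleep(0.1)
      os.kill(pid, signal.SIGKILL)  # Force kill after the grace period.
      return False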
@@ -539,34 +623,60 @@
   def _get_mpr_result(self, process_statuses):
     stdout = self._queue_to_list(self._streaming_queue)
     return_values = []
-    for process_status in process_statuses:
+    for process_status in process_statuses.values():
       if process_status.return_value is not None:
         return_values.append(process_status.return_value)
     return MultiProcessRunnerResult(stdout=stdout, return_value=return_values)
 
   def terminate(self, task_type, task_id):
-    """Terminates the process with `task_type` and `task_id`."""
-    p = self._processes.get((task_type, task_id), None)
-    if p is None:
-      raise ValueError('{}-{} does not exist'.format(task_type, task_id))
-    # TODO(crccw): change to use Process.terminate() as well.
-    self._parent_to_sub_queue.put('terminate {} {}'.format(task_type, task_id))
-    p.join()
+    """Terminates the process with `task_type` and `task_id`.
+
+    If auto_restart=True, the terminated task will be restarted unless the
+    chief has already exited with a zero exit code.
+
+    Args:
+      task_type: the task type.
+      task_id: the task id.
+    """
+    with self._process_lock:
+      p = self._processes.get((task_type, task_id), None)
+      if p is None:
+        raise ValueError('{}-{} does not exist'.format(task_type, task_id))
+      self._terminated.add((task_type, task_id))
+      # TODO(crccw): change to use Process.terminate() as well.
+      self._parent_to_sub_queue.put('terminate {} {}'.format(
+          task_type, task_id))
+      p.join()
 
-  def terminate_all(self, sig=None):
-    """Terminates all subprocesses."""
+  def _terminate_all(self, sig=None):
+    """Terminates all subprocesses.
+
+    The caller is required to hold self._process_lock.
+
+    Args:
+      sig: the signal used to terminate the process. The default is SIGKILL.
+    """
+
     # Use SIGKILL as default. In systems where that's unavailable such as
     # windows, use SIGTERM.
     sig = sig or getattr(signal, 'SIGKILL', signal.SIGTERM)
     for (task_type, task_id), p in self._processes.items():
+      if p.exitcode is not None:
+        continue
       try:
         os.kill(p.pid, sig)
+        self._terminated.add((task_type, task_id))
         logging.info('%s-%d terminated with signal %r.', task_type, task_id,
                      sig)
       except ProcessLookupError:
         logging.info('Attempting to kill %s-%d but it does not exist.',
                      task_type, task_id)
-    self._all_forced_terminated = True
+
+  def terminate_all(self, sig=None):
+    """Terminates all subprocesses."""
+    with self._process_lock:
+      self._terminate_all(sig)
 
 
 class _Process(multi_process_lib.Process):
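The terminate_all()/_terminate_all() split follows the usual lock-split idiom: the public method acquires self._process_lock, while the private body assumes the caller already holds it, so the watchdog loop (which runs under the lock) can call _terminate_all() without deadlocking. A generic sketch of the idiom (names are illustrative):

    import threading

    class Example(object):

      def __init__(self):
        self._lock = threading.Lock()

      def _do_work(self):
        # Caller is required to hold self._lock, mirroring _terminate_all().
        pass

      def do_work(self):
        # Public entry point acquires the lock itself.
        with self._lock:
          self._do_work()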
@@ -625,11 +735,13 @@ class _ProcFunc(object):
       time.sleep(0.1)
     self._resources.process_status_queue.put(
         _ProcessStatusInfo(
+            task_type=task_type,
+            task_id=task_id,
             is_successful=True,
             exc_info=None,
             return_value=None))
-    # `os._exit(0)` is used to more reliably terminate a subprocess.
-    os._exit(0)  # pylint: disable=protected-access
+    # `os._exit(1)` is used to more reliably terminate a subprocess.
+    os._exit(1)  # pylint: disable=protected-access
 
   def _close_streaming(self):
     """Close stdout, stderr and streaming pipe.
@@ -685,7 +797,8 @@ class _ProcFunc(object):
     v2_compat.enable_v2_behavior()
 
     with self._runtime_mode(test_env.executing_eagerly):
-      info = _run_contained(proc_func, args, kwargs)
+      info = _run_contained(test_env.task_type, test_env.task_id, proc_func,
+                            args, kwargs)
       self._resources.process_status_queue.put(info)
 
       # Re-raise the exception in addition to reporting it to the parent
@@ -774,7 +887,7 @@ class MultiProcessPoolRunner(object):
           task_type,
           task_id,
           proc_func=_pool_runner_worker,
-          args=(initializer, conn2))
+          args=(task_type, task_id, initializer, conn2))
 
   def run(self, proc_func, args=None, kwargs=None):
     """Runs `proc_func` with `args` and `kwargs` on all jobs.
@@ -819,7 +932,7 @@ class MultiProcessPoolRunner(object):
     return return_values
 
 
-def _pool_runner_worker(initializer, conn):
+def _pool_runner_worker(task_type, task_id, initializer, conn):
   """Function that runs on the workers in a pool.
 
   It listens for callables to run and returns the result until `conn` is closed.
@@ -827,8 +940,10 @@ def _pool_runner_worker(initializer, conn):
   `conn`.
 
   Args:
-    initializer: A callable to execute during startup.
-    conn: A multiprocessing.Connection object to listen for tasks and send
+    task_type: the task type.
+    task_id: the task index.
+    initializer: a callable to execute during startup.
+    conn: a multiprocessing.Connection object to listen for tasks and send
       results.
   """
   if initializer:
@@ -840,22 +955,24 @@ def _pool_runner_worker(initializer, conn):
     except EOFError:
       break
     proc_func = dill.loads(proc_func)
-    info = _run_contained(proc_func, args, kwargs)
+    info = _run_contained(task_type, task_id, proc_func, args, kwargs)
     sys.stdout.flush()
     sys.stderr.flush()
     conn.send(info)
 
 
-def _run_contained(proc_func, args, kwargs):
+def _run_contained(task_type, task_id, proc_func, args, kwargs):
   """Runs `proc_func` with `args` and `kwargs`.
 
   The function returns _ProcessStatusInfo which captures the return value and
   the exception.
 
   Args:
-    proc_func: The function to be run.
-    args: Optional positional arguments to be supplied in `proc_func`.
-    kwargs: Optional keyword arguments to be supplied in `proc_func`.
+    task_type: the task type.
+    task_id: the task index.
+    proc_func: the function to be run.
+    args: optional positional arguments to be supplied in `proc_func`.
+    kwargs: optional keyword arguments to be supplied in `proc_func`.
 
   Returns:
     a _ProcessStatusInfo.
@@ -868,6 +985,8 @@ def _run_contained(proc_func, args, kwargs):
     return_value = proc_func(*args, **kwargs)
     is_successful = True
     return _ProcessStatusInfo(
+        task_type=task_type,
+        task_id=task_id,
         is_successful=is_successful,
         exc_info=exc_info,
         return_value=return_value)
@@ -877,6 +996,8 @@ def _run_contained(proc_func, args, kwargs):
   except Exception:  # pylint: disable=broad-except
     exc_info = sys.exc_info()
     return _ProcessStatusInfo(
+        task_type=task_type,
+        task_id=task_id,
         is_successful=is_successful,
         exc_info=exc_info,
         return_value=return_value)

--- tensorflow/python/distribute/multi_process_runner_test.py
+++ tensorflow/python/distribute/multi_process_runner_test.py
@@ -156,11 +156,8 @@ class MultiProcessRunnerTest(test.TestCase):
     mpr.start()
     time.sleep(5)
     mpr.terminate('worker', 0)
-    with self.assertRaises(
-        multi_process_runner.UnexpectedSubprocessExitError) as cm:
-      mpr.join()
-
-    std_stream_results = cm.exception.mpr_result.stdout
+    std_stream_results = mpr.join().stdout
 
     # Worker 0 is terminated in the middle, so it should not have iteration 9
     # printed.
@@ -388,6 +385,99 @@
         'Subprocess worker-0 exited with exit code 10'):
       mpr.join()
 
+  def test_auto_restart(self):
+
+    def proc_func(counter):
+      counter.value += 1
+      if counter.value == 1:
+        raise ValueError
+
+    manager = multi_process_runner.manager()
+    counter = manager.Value(int, 0)
+    mpr = multi_process_runner.MultiProcessRunner(
+        proc_func,
+        multi_worker_test_base.create_cluster_spec(num_workers=1),
+        args=(counter,),
+        auto_restart=True)
+    mpr.start()
+    mpr.join()
+    self.assertEqual(counter.value, 2)
+
+  def test_auto_restart_and_timeout(self):
+
+    def proc_func():
+      time.sleep(1)
+      raise ValueError
+
+    mpr = multi_process_runner.MultiProcessRunner(
+        proc_func,
+        multi_worker_test_base.create_cluster_spec(num_workers=1),
+        auto_restart=True)
+    mpr.start()
+    with self.assertRaises(multi_process_runner.SubprocessTimeoutError):
+      mpr.join(timeout=10)
+
+  def test_auto_restart_and_chief(self):
+    # If the chief has exited with zero exit code, auto restart should stop
+    # restarting other tasks even if they fail.
+
+    def proc_func():
+      time.sleep(1)
+      if multi_worker_test_base.get_task_type() != 'chief':
+        raise ValueError
+
+    manager = multi_process_runner.manager()
+    mpr = multi_process_runner.MultiProcessRunner(
+        proc_func,
+        multi_worker_test_base.create_cluster_spec(
+            has_chief=True, num_workers=1),
+        auto_restart=True)
+    mpr.start()
+    with self.assertRaises(ValueError):
+      mpr.join(timeout=10)
+
+  def test_auto_restart_failure_immediate_after_restart(self):
+    # Test the case when worker-0 fails immediately after worker-1 restarts.
+
+    def proc_func():
+      time.sleep(5)
+
+    mpr = multi_process_runner.MultiProcessRunner(
+        proc_func,
+        multi_worker_test_base.create_cluster_spec(
+            has_chief=False, num_workers=2),
+        auto_restart=True)
+    mpr.start()
+    pid = mpr.get_process_id('worker', 1)
+    mpr.terminate('worker', 1)
+    while mpr.get_process_id('worker', 1) == pid:
+      time.sleep(0.1)
+    mpr.terminate('worker', 0)
+    mpr.join(timeout=20)
+
+  def test_auto_restart_terminate(self):
+    # Tasks terminated by the user should also be restarted.
+
+    def proc_func(counter):
+      counter.value += 1
+      if counter.value == 1:
+        time.sleep(100)
+
+    manager = multi_process_runner.manager()
+    counter = manager.Value(int, 0)
+
+    mpr = multi_process_runner.MultiProcessRunner(
+        proc_func,
+        multi_worker_test_base.create_cluster_spec(
+            has_chief=False, num_workers=1),
+        args=(counter,),
+        auto_restart=True)
+    mpr.start()
+    time.sleep(3)
+    mpr.terminate('worker', 0)
+    mpr.join(timeout=20)
+    self.assertEqual(counter.value, 2)
+
+
 class MultiProcessPoolRunnerTest(test.TestCase):
 