Add auto_restart to multi_process_runner
This helps creating fault tolerance test cases. MWMS currently requires an external system which brings back tasks that are down, otherwise the remaining workers may hang forever. Ideally the remaining workers should error, which is what I'm working on. But it's beneficial to have test cases reflecting the current behavior since in many deployment, we do have a cluster management system that does the restart (e.g. k8s). This also changes the behavior of dependence_on_chief. We used to terminate the cluster if chief exits when join() is called. Now with a watchdog thread, that happens immediately after the chief's terminate. PiperOrigin-RevId: 324094287 Change-Id: If183d839168d71a2ffe3de10d1410e2116ea1f80
This commit is contained in:
parent
c111987612
commit
56bb1ccfaa
@ -67,7 +67,8 @@ except ImportError:
|
||||
# exception stack trace info is stored in exc_info to pass on to parent process
|
||||
# to be re-raised.
|
||||
_ProcessStatusInfo = collections.namedtuple(
|
||||
'_ProcessStatusInfo', ['is_successful', 'exc_info', 'return_value'])
|
||||
'_ProcessStatusInfo',
|
||||
['task_type', 'task_id', 'is_successful', 'exc_info', 'return_value'])
|
||||
|
||||
# Information returned from a successful MultiProcessRunner run.
|
||||
MultiProcessRunnerResult = collections.namedtuple('MultiProcessRunnerResult',
|
||||
@ -124,6 +125,8 @@ class MultiProcessRunner(object):
|
||||
list_stdout=False,
|
||||
use_dill_for_args=True,
|
||||
daemon=False,
|
||||
dependence_on_chief=True,
|
||||
auto_restart=False,
|
||||
args=None,
|
||||
kwargs=None):
|
||||
"""Creates a multi-process runner.
|
||||
@ -161,6 +164,11 @@ class MultiProcessRunner(object):
|
||||
can pickle more objects, but doesn't work with types in
|
||||
`multiprocessing` library like `Mutex`.
|
||||
daemon: Whether to start processes as daemons.
|
||||
dependence_on_chief: Whether to terminates the cluster if the chief exits.
|
||||
If auto_restart is True, it only terminates the cluster if the chief
|
||||
exits with a zero exit code.
|
||||
auto_restart: Whether to automatically restart processes that exit with
|
||||
non-zero exit code.
|
||||
args: Positional arguments to be sent to functions run on processes.
|
||||
kwargs: Keyword arguments to be sent to functions run on processes.
|
||||
|
||||
@ -190,9 +198,10 @@ class MultiProcessRunner(object):
|
||||
self._stream_stdout = stream_stdout
|
||||
# TODO(rchao): Revisit list_stdout argument to consider other solution.
|
||||
self._list_stdout = list_stdout
|
||||
self._dependence_on_chief = True
|
||||
self._dependence_on_chief = dependence_on_chief
|
||||
self._use_dill_for_args = use_dill_for_args
|
||||
self._daemon = daemon
|
||||
self._auto_restart = auto_restart
|
||||
self._args = args or ()
|
||||
self._kwargs = kwargs or {}
|
||||
|
||||
@ -201,8 +210,15 @@ class MultiProcessRunner(object):
|
||||
self._executing_eagerly = context.executing_eagerly()
|
||||
|
||||
self._joined = False
|
||||
self._process_lock = threading.Lock()
|
||||
# Guarded by self._process_lock.
|
||||
self._processes = {}
|
||||
self._outstanding_subprocess_count = 0
|
||||
# Record which processes are terminated. Due to a bug in Python<3.7,
|
||||
# terminated processes return 255 exit code, which should cause an exception
|
||||
# in join().
|
||||
# https://bugs.python.org/issue30589
|
||||
# Guarded by self._process_lock.
|
||||
self._terminated = set()
|
||||
self._reading_threads = []
|
||||
|
||||
self._manager = manager()
|
||||
@ -215,8 +231,7 @@ class MultiProcessRunner(object):
|
||||
# safe.
|
||||
self._streaming_queue = self._manager.Queue()
|
||||
|
||||
# This flag will be set to True once terminate_all() is called.
|
||||
self._all_forced_terminated = False
|
||||
self._watchdog_thread = None
|
||||
|
||||
def set_args(self, args=None, kwargs=None):
|
||||
self._args = args or self._args
|
||||
@ -281,7 +296,7 @@ class MultiProcessRunner(object):
|
||||
daemon=self._daemon)
|
||||
p.start()
|
||||
self._processes[(task_type, task_id)] = p
|
||||
self._outstanding_subprocess_count += 1
|
||||
self._terminated.discard((task_type, task_id))
|
||||
|
||||
# For each subprocess, we dedicate a thread continuously reading lines
|
||||
# from them.
|
||||
@ -291,17 +306,26 @@ class MultiProcessRunner(object):
|
||||
thread.start()
|
||||
self._reading_threads.append(thread)
|
||||
|
||||
if self._watchdog_thread is None or not self._watchdog_thread.is_alive():
|
||||
self._watchdog_thread = threading.Thread(target=self._process_watchdog)
|
||||
self._watchdog_thread.start()
|
||||
|
||||
def start(self):
|
||||
"""Starts processes, one for each task in `cluster_spec`.
|
||||
|
||||
Note that this is best effort by the applicable multiprocessing library,
|
||||
and it may take up to seconds for a subprocess to be successfully started.
|
||||
"""
|
||||
if self._processes:
|
||||
raise ValueError('MultiProcessRunner already started.')
|
||||
for task_type, addresses in self._cluster_spec.items():
|
||||
for task_id, _ in enumerate(addresses):
|
||||
self._start_subprocess_and_reading_thread(task_type, task_id)
|
||||
with self._process_lock:
|
||||
if self._processes:
|
||||
raise ValueError('MultiProcessRunner already started.')
|
||||
if self._joined:
|
||||
raise ValueError('cannot start new processes after'
|
||||
'MultiProcessRunner.join() is called')
|
||||
|
||||
for task_type, addresses in self._cluster_spec.items():
|
||||
for task_id, _ in enumerate(addresses):
|
||||
self._start_subprocess_and_reading_thread(task_type, task_id)
|
||||
|
||||
# TODO(rchao): Remove the need of using SIGALRM if possible. At this time,
|
||||
# without this the tests become very flaky.
|
||||
@ -353,10 +377,14 @@ class MultiProcessRunner(object):
|
||||
"""
|
||||
if self._processes:
|
||||
raise ValueError('MultiProcessRunner already started.')
|
||||
for task_type, addresses in self._cluster_spec.items():
|
||||
for task_id, _ in enumerate(addresses):
|
||||
if not (task_type == as_task_type and task_id == as_task_id):
|
||||
self._start_subprocess_and_reading_thread(task_type, task_id)
|
||||
with self._process_lock:
|
||||
if self._joined:
|
||||
raise ValueError('cannot start new processes after'
|
||||
'MultiProcessRunner.join() is called')
|
||||
for task_type, addresses in self._cluster_spec.items():
|
||||
for task_id, _ in enumerate(addresses):
|
||||
if not (task_type == as_task_type and task_id == as_task_id):
|
||||
self._start_subprocess_and_reading_thread(task_type, task_id)
|
||||
|
||||
_set_tf_config(as_task_type, as_task_id, self._cluster_spec,
|
||||
self._rpc_layer)
|
||||
@ -392,13 +420,17 @@ class MultiProcessRunner(object):
|
||||
args: Optional positional arguments to be supplied in `proc_func`.
|
||||
kwargs: Optional keyword arguments to be supplied in `proc_func`.
|
||||
"""
|
||||
self._start_subprocess_and_reading_thread(
|
||||
task_type,
|
||||
task_id,
|
||||
cluster_spec=cluster_spec,
|
||||
proc_func=proc_func,
|
||||
args=args or (),
|
||||
kwargs=kwargs or {})
|
||||
with self._process_lock:
|
||||
if self._joined:
|
||||
raise ValueError('cannot start new processes after'
|
||||
'MultiProcessRunner.join() is called')
|
||||
self._start_subprocess_and_reading_thread(
|
||||
task_type,
|
||||
task_id,
|
||||
cluster_spec=cluster_spec,
|
||||
proc_func=proc_func,
|
||||
args=args or (),
|
||||
kwargs=kwargs or {})
|
||||
|
||||
def _queue_to_list(self, queue_to_convert):
|
||||
"""Convert `queue.Queue` to `list`."""
|
||||
@ -411,9 +443,17 @@ class MultiProcessRunner(object):
|
||||
break
|
||||
return list_to_return
|
||||
|
||||
def _get_process_statuses(self):
|
||||
# One worker may have multiple statuses. We only keep the last one.
|
||||
statuses = {}
|
||||
for status in self._queue_to_list(self._process_status_queue):
|
||||
statuses[(status.task_type, status.task_id)] = status
|
||||
return statuses
|
||||
|
||||
def get_process_id(self, task_type, task_id):
|
||||
"""Returns the subprocess id given the task type and task id."""
|
||||
p = self._processes.get((task_type, task_id), None)
|
||||
with self._process_lock:
|
||||
p = self._processes.get((task_type, task_id), None)
|
||||
return p.pid if p else None
|
||||
|
||||
def get_process_exit_code(self, task_type, task_id):
|
||||
@ -430,22 +470,54 @@ class MultiProcessRunner(object):
|
||||
KeyError: If the corresponding subprocess is not found with `task_type`
|
||||
and `task_id`.
|
||||
"""
|
||||
p = self._processes[(task_type, task_id)]
|
||||
with self._process_lock:
|
||||
p = self._processes[(task_type, task_id)]
|
||||
return p.exitcode if p else None
|
||||
|
||||
def _join_or_terminate(self, task_type, task_id, process, timeout):
|
||||
"""Joins a process. If it times out, terminate all procsses."""
|
||||
logging.info('joining %s-%d', task_type, task_id)
|
||||
process.join(timeout)
|
||||
# If exitcode is None, the process aren't terminated and this is a
|
||||
# timeout.
|
||||
if process.exitcode is None:
|
||||
# Force termination to dump worker processes stack trace.
|
||||
self.terminate_all(sig=signal.SIGTERM)
|
||||
process_statuses = self._queue_to_list(self._process_status_queue)
|
||||
raise SubprocessTimeoutError(
|
||||
'%s-%d and possibly more subprocesses timed out.' %
|
||||
(task_type, task_id), self._get_mpr_result(process_statuses))
|
||||
def _process_watchdog(self):
|
||||
"""Simulates a cluster management system.
|
||||
|
||||
- If auto_restart is True, it restarts processes that exit with a non-zero
|
||||
exit code. Note that when join() times out it overrides auto_restart to
|
||||
False.
|
||||
- If dependence_on_chief is True, it terminates all processes once the chief
|
||||
exits. If auto_restart is also True, it only terminates all processes if
|
||||
the chief exit with a zero exit code, otherwise it restarts the chief.
|
||||
|
||||
This runs in self._watchdog_thread.
|
||||
"""
|
||||
while True:
|
||||
time.sleep(1)
|
||||
with self._process_lock:
|
||||
chief = self._processes.get(('chief', 0), None)
|
||||
# Terminate the cluster when _dependence_on_chief is True if either:
|
||||
# - chief has exited with zero exit code.
|
||||
# - chief has exited with non-zero exit code and self._auto_restart is
|
||||
# False.
|
||||
if chief and self._dependence_on_chief and chief.exitcode is not None:
|
||||
if chief.exitcode == 0 or (not self._auto_restart):
|
||||
for p in self._processes.values():
|
||||
# Give other processes a chance to exit on their own.
|
||||
p.join(timeout=3)
|
||||
self._terminate_all()
|
||||
for p in self._processes.values():
|
||||
p.join()
|
||||
return
|
||||
|
||||
# Auto restart failed processes if self._auto_restart is True.
|
||||
if self._auto_restart:
|
||||
has_failure = False
|
||||
for (task_type, task_id), p in self._processes.items():
|
||||
if p.exitcode is not None and p.exitcode != 0:
|
||||
has_failure = True
|
||||
logging.info('Restarting failed %s-%d', task_type, task_id)
|
||||
self._start_subprocess_and_reading_thread(task_type, task_id)
|
||||
if has_failure:
|
||||
continue
|
||||
|
||||
# Exit the thread if all processes have exited at this point.
|
||||
if all(p.exitcode is not None for p in self._processes.values()):
|
||||
return
|
||||
|
||||
def join(self, timeout=_DEFAULT_TIMEOUT_SEC):
|
||||
"""Joins all the processes with timeout.
|
||||
@ -489,41 +561,40 @@ class MultiProcessRunner(object):
|
||||
cases.
|
||||
Exception: if there is an Exception propagated from any subprocess.
|
||||
"""
|
||||
if self._joined:
|
||||
raise ValueError("MultiProcessRunner can't be joined twice.")
|
||||
self._joined = True
|
||||
with self._process_lock:
|
||||
if self._joined:
|
||||
raise ValueError("MultiProcessRunner can't be joined twice.")
|
||||
self._joined = True
|
||||
|
||||
chief = self._processes.get(('chief', 0), None)
|
||||
if self._dependence_on_chief and chief:
|
||||
self._join_or_terminate('chief', 0, chief, timeout)
|
||||
# Give other processes a chance to exit on their own.
|
||||
for p in self._processes.values():
|
||||
p.join(timeout=3)
|
||||
self.terminate_all()
|
||||
else:
|
||||
for (task_type, task_id), p in self._processes.items():
|
||||
self._join_or_terminate(task_type, task_id, p, timeout)
|
||||
self._watchdog_thread.join(timeout)
|
||||
if self._watchdog_thread.is_alive():
|
||||
# Timeout. Force termination to dump worker processes stack trace.
|
||||
with self._process_lock:
|
||||
self._auto_restart = False
|
||||
self.terminate_all(sig=signal.SIGTERM)
|
||||
self._watchdog_thread.join()
|
||||
process_statuses = self._get_process_statuses()
|
||||
raise SubprocessTimeoutError('one or more subprocesses timed out.',
|
||||
self._get_mpr_result(process_statuses))
|
||||
|
||||
for (task_type, task_id), p in self._processes.items():
|
||||
logging.info('%s-%d exit code: %s', task_type, task_id, p.exitcode)
|
||||
|
||||
process_statuses = self._queue_to_list(self._process_status_queue)
|
||||
for process_status in process_statuses:
|
||||
process_statuses = self._get_process_statuses()
|
||||
for process_status in process_statuses.values():
|
||||
assert isinstance(process_status, _ProcessStatusInfo)
|
||||
if not process_status.is_successful:
|
||||
six.reraise(*process_status.exc_info)
|
||||
|
||||
# Checking all the processes that are expected to exit properly.
|
||||
for (task_type, task_id), p in self._processes.items():
|
||||
if self._dependence_on_chief and chief and task_type != 'chief':
|
||||
# If _dependence_on_chief, other processes may have been
|
||||
# forced-terminated, which is expected.
|
||||
continue
|
||||
# Successfully exiting process has exit code 0.
|
||||
if p.exitcode is None or p.exitcode > 0:
|
||||
# Successfully exiting process has exit code 0. We ignore processes that
|
||||
# are terminated.
|
||||
assert p.exitcode is not None
|
||||
if (p.exitcode > 0 and (task_type, task_id) not in self._terminated):
|
||||
raise UnexpectedSubprocessExitError(
|
||||
'Subprocess %s-%d exited with exit code %d. See logs for details.' %
|
||||
(task_type, task_id, p.exitcode),
|
||||
'Subprocess %s-%d exited with exit code %s. See logs for details.'
|
||||
% (task_type, task_id, p.exitcode),
|
||||
self._get_mpr_result(process_statuses))
|
||||
|
||||
logging.info('Joining log reading threads.')
|
||||
@ -539,34 +610,60 @@ class MultiProcessRunner(object):
|
||||
def _get_mpr_result(self, process_statuses):
|
||||
stdout = self._queue_to_list(self._streaming_queue)
|
||||
return_values = []
|
||||
for process_status in process_statuses:
|
||||
for process_status in process_statuses.values():
|
||||
if process_status.return_value is not None:
|
||||
return_values.append(process_status.return_value)
|
||||
return MultiProcessRunnerResult(stdout=stdout, return_value=return_values)
|
||||
|
||||
def terminate(self, task_type, task_id):
|
||||
"""Terminates the process with `task_type` and `task_id`."""
|
||||
p = self._processes.get((task_type, task_id), None)
|
||||
if p is None:
|
||||
raise ValueError('{}-{} does not exist'.format(task_type, task_id))
|
||||
# TODO(crccw): change to use Process.terminate() as well.
|
||||
self._parent_to_sub_queue.put('terminate {} {}'.format(task_type, task_id))
|
||||
p.join()
|
||||
"""Terminates the process with `task_type` and `task_id`.
|
||||
|
||||
If auto_retart=True, the terminated task will be restarted unless the chief
|
||||
has already exited with zero exit code.
|
||||
|
||||
Args:
|
||||
task_type: the task type.
|
||||
task_id: the task id.
|
||||
|
||||
"""
|
||||
with self._process_lock:
|
||||
p = self._processes.get((task_type, task_id), None)
|
||||
if p is None:
|
||||
raise ValueError('{}-{} does not exist'.format(task_type, task_id))
|
||||
self._terminated.add((task_type, task_id))
|
||||
# TODO(crccw): change to use Process.terminate() as well.
|
||||
self._parent_to_sub_queue.put('terminate {} {}'.format(
|
||||
task_type, task_id))
|
||||
p.join()
|
||||
|
||||
def _terminate_all(self, sig=None):
|
||||
"""Terminates all subprocesses.
|
||||
|
||||
The caller is required to hold self._process_lock.
|
||||
|
||||
Args:
|
||||
sig: the signal used to terminate the process. The default is SIGKILL.
|
||||
"""
|
||||
|
||||
def terminate_all(self, sig=None):
|
||||
"""Terminates all subprocesses."""
|
||||
# Use SIGKILL as default. In systems where that's unavailable such as
|
||||
# windows, use SIGTERM.
|
||||
sig = sig or getattr(signal, 'SIGKILL', signal.SIGTERM)
|
||||
for (task_type, task_id), p in self._processes.items():
|
||||
if p.exitcode is not None:
|
||||
continue
|
||||
try:
|
||||
os.kill(p.pid, sig)
|
||||
self._terminated.add((task_type, task_id))
|
||||
logging.info('%s-%d terminated with signal %r.', task_type, task_id,
|
||||
sig)
|
||||
except ProcessLookupError:
|
||||
logging.info('Attempting to kill %s-%d but it does not exist.',
|
||||
task_type, task_id)
|
||||
self._all_forced_terminated = True
|
||||
|
||||
def terminate_all(self, sig=None):
|
||||
"""Terminates all subprocesses."""
|
||||
with self._process_lock:
|
||||
self._terminate_all(sig)
|
||||
|
||||
|
||||
class _Process(multi_process_lib.Process):
|
||||
@ -625,11 +722,13 @@ class _ProcFunc(object):
|
||||
time.sleep(0.1)
|
||||
self._resources.process_status_queue.put(
|
||||
_ProcessStatusInfo(
|
||||
task_type=task_type,
|
||||
task_id=task_id,
|
||||
is_successful=True,
|
||||
exc_info=None,
|
||||
return_value=None))
|
||||
# `os._exit(0)` is used to more reliably terminate a subprocess.
|
||||
os._exit(0) # pylint: disable=protected-access
|
||||
# `os._exit(1)` is used to more reliably terminate a subprocess.
|
||||
os._exit(1) # pylint: disable=protected-access
|
||||
|
||||
def _close_streaming(self):
|
||||
"""Close stdout, stderr and streaming pipe.
|
||||
@ -685,7 +784,8 @@ class _ProcFunc(object):
|
||||
v2_compat.enable_v2_behavior()
|
||||
|
||||
with self._runtime_mode(test_env.executing_eagerly):
|
||||
info = _run_contained(proc_func, args, kwargs)
|
||||
info = _run_contained(test_env.task_type, test_env.task_id, proc_func,
|
||||
args, kwargs)
|
||||
self._resources.process_status_queue.put(info)
|
||||
|
||||
# Re-raise the exception in addition to reporting it to the parent
|
||||
@ -774,7 +874,7 @@ class MultiProcessPoolRunner(object):
|
||||
task_type,
|
||||
task_id,
|
||||
proc_func=_pool_runner_worker,
|
||||
args=(initializer, conn2))
|
||||
args=(task_type, task_id, initializer, conn2))
|
||||
|
||||
def run(self, proc_func, args=None, kwargs=None):
|
||||
"""Runs `proc_func` with `args` and `kwargs` on all jobs.
|
||||
@ -819,7 +919,7 @@ class MultiProcessPoolRunner(object):
|
||||
return return_values
|
||||
|
||||
|
||||
def _pool_runner_worker(initializer, conn):
|
||||
def _pool_runner_worker(task_type, task_id, initializer, conn):
|
||||
"""Function that runs on the workers in a pool.
|
||||
|
||||
It listens for callables to run and returns the result until `conn` is closed.
|
||||
@ -827,8 +927,10 @@ def _pool_runner_worker(initializer, conn):
|
||||
`conn`.
|
||||
|
||||
Args:
|
||||
initializer: A callable to execute during startup.
|
||||
conn: A multiprocessing.Connection object to listen for tasks and send
|
||||
task_type: the task type.
|
||||
task_id: the task index.
|
||||
initializer: a callable to execute during startup.
|
||||
conn: a multiprocessing.Connection object to listen for tasks and send
|
||||
results.
|
||||
"""
|
||||
if initializer:
|
||||
@ -840,22 +942,24 @@ def _pool_runner_worker(initializer, conn):
|
||||
except EOFError:
|
||||
break
|
||||
proc_func = dill.loads(proc_func)
|
||||
info = _run_contained(proc_func, args, kwargs)
|
||||
info = _run_contained(task_type, task_id, proc_func, args, kwargs)
|
||||
sys.stdout.flush()
|
||||
sys.stderr.flush()
|
||||
conn.send(info)
|
||||
|
||||
|
||||
def _run_contained(proc_func, args, kwargs):
|
||||
def _run_contained(task_type, task_id, proc_func, args, kwargs):
|
||||
"""Runs `proc_func` with `args` and `kwargs`.
|
||||
|
||||
The function returns _ProcessStatusInfo which captures the return value and
|
||||
the exception.
|
||||
|
||||
Args:
|
||||
proc_func: The function to be run.
|
||||
args: Optional positional arguments to be supplied in `proc_func`.
|
||||
kwargs: Optional keyword arguments to be supplied in `proc_func`.
|
||||
task_type: the task type.
|
||||
task_id: the task index.
|
||||
proc_func: the function to be run.
|
||||
args: optional positional arguments to be supplied in `proc_func`.
|
||||
kwargs: optional keyword arguments to be supplied in `proc_func`.
|
||||
|
||||
Returns:
|
||||
a _ProcessStatusInfo.
|
||||
@ -868,6 +972,8 @@ def _run_contained(proc_func, args, kwargs):
|
||||
return_value = proc_func(*args, **kwargs)
|
||||
is_successful = True
|
||||
return _ProcessStatusInfo(
|
||||
task_type=task_type,
|
||||
task_id=task_id,
|
||||
is_successful=is_successful,
|
||||
exc_info=exc_info,
|
||||
return_value=return_value)
|
||||
@ -877,6 +983,8 @@ def _run_contained(proc_func, args, kwargs):
|
||||
except Exception: # pylint: disable=broad-except
|
||||
exc_info = sys.exc_info()
|
||||
return _ProcessStatusInfo(
|
||||
task_type=task_type,
|
||||
task_id=task_id,
|
||||
is_successful=is_successful,
|
||||
exc_info=exc_info,
|
||||
return_value=return_value)
|
||||
|
@ -156,11 +156,8 @@ class MultiProcessRunnerTest(test.TestCase):
|
||||
mpr.start()
|
||||
time.sleep(5)
|
||||
mpr.terminate('worker', 0)
|
||||
with self.assertRaises(
|
||||
multi_process_runner.UnexpectedSubprocessExitError) as cm:
|
||||
mpr.join()
|
||||
|
||||
std_stream_results = cm.exception.mpr_result.stdout
|
||||
std_stream_results = mpr.join().stdout
|
||||
|
||||
# Worker 0 is terminated in the middle, so it should not have iteration 9
|
||||
# printed.
|
||||
@ -388,6 +385,99 @@ class MultiProcessRunnerTest(test.TestCase):
|
||||
'Subprocess worker-0 exited with exit code 10'):
|
||||
mpr.join()
|
||||
|
||||
def test_auto_restart(self):
|
||||
|
||||
def proc_func(counter):
|
||||
counter.value += 1
|
||||
if counter.value == 1:
|
||||
raise ValueError
|
||||
|
||||
manager = multi_process_runner.manager()
|
||||
counter = manager.Value(int, 0)
|
||||
mpr = multi_process_runner.MultiProcessRunner(
|
||||
proc_func,
|
||||
multi_worker_test_base.create_cluster_spec(num_workers=1),
|
||||
args=(counter,),
|
||||
auto_restart=True)
|
||||
mpr.start()
|
||||
mpr.join()
|
||||
self.assertEqual(counter.value, 2)
|
||||
|
||||
def test_auto_restart_and_timeout(self):
|
||||
|
||||
def proc_func():
|
||||
time.sleep(1)
|
||||
raise ValueError
|
||||
|
||||
mpr = multi_process_runner.MultiProcessRunner(
|
||||
proc_func,
|
||||
multi_worker_test_base.create_cluster_spec(num_workers=1),
|
||||
auto_restart=True)
|
||||
mpr.start()
|
||||
with self.assertRaises(multi_process_runner.SubprocessTimeoutError):
|
||||
mpr.join(timeout=10)
|
||||
|
||||
def test_auto_restart_and_chief(self):
|
||||
# If the chief has exited with zero exit code, auto restart should stop
|
||||
# restarting other tasks even if they fail.
|
||||
|
||||
def proc_func():
|
||||
time.sleep(1)
|
||||
if multi_worker_test_base.get_task_type() != 'chief':
|
||||
raise ValueError
|
||||
|
||||
manager = multi_process_runner.manager()
|
||||
mpr = multi_process_runner.MultiProcessRunner(
|
||||
proc_func,
|
||||
multi_worker_test_base.create_cluster_spec(
|
||||
has_chief=True, num_workers=1),
|
||||
auto_restart=True)
|
||||
mpr.start()
|
||||
with self.assertRaises(ValueError):
|
||||
mpr.join(timeout=10)
|
||||
|
||||
def test_auto_restart_failure_immediate_after_restart(self):
|
||||
# Test the case when worker-0 fails immediately after worker-1 restarts.
|
||||
|
||||
def proc_func():
|
||||
time.sleep(5)
|
||||
|
||||
mpr = multi_process_runner.MultiProcessRunner(
|
||||
proc_func,
|
||||
multi_worker_test_base.create_cluster_spec(
|
||||
has_chief=False, num_workers=2),
|
||||
auto_restart=True)
|
||||
mpr.start()
|
||||
pid = mpr.get_process_id('worker', 1)
|
||||
mpr.terminate('worker', 1)
|
||||
while mpr.get_process_id('worker', 1) == pid:
|
||||
time.sleep(0.1)
|
||||
mpr.terminate('worker', 0)
|
||||
mpr.join(timeout=20)
|
||||
|
||||
def test_auto_restart_terminate(self):
|
||||
# Tasks terminated by the user should also be restarted.
|
||||
|
||||
def proc_func(counter):
|
||||
counter.value += 1
|
||||
if counter.value == 1:
|
||||
time.sleep(100)
|
||||
|
||||
manager = multi_process_runner.manager()
|
||||
counter = manager.Value(int, 0)
|
||||
|
||||
mpr = multi_process_runner.MultiProcessRunner(
|
||||
proc_func,
|
||||
multi_worker_test_base.create_cluster_spec(
|
||||
has_chief=False, num_workers=1),
|
||||
args=(counter,),
|
||||
auto_restart=True)
|
||||
mpr.start()
|
||||
time.sleep(3)
|
||||
mpr.terminate('worker', 0)
|
||||
mpr.join(timeout=20)
|
||||
self.assertEqual(counter.value, 2)
|
||||
|
||||
|
||||
class MultiProcessPoolRunnerTest(test.TestCase):
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user